Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-20 07:46:42 -05:00)

Compare commits: meta-insta...fix-partia (16 commits)
Commits:

- 27b4993e64
- bddad7e79c
- addf73a144
- a16ff2c047
- 3006c8ea4e
- f662c129dd
- c45ff9ad43
- 7031901ae5
- cf648a53b8
- 94b2ce6922
- 423ed0f07f
- ed001f2409
- 4c4c6ce99f
- 42e1e7322b
- aa3f106fb9
- 526cd9f333
@@ -20,6 +20,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
     settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
     selected.sort(key=_placement_sort_key)
     preview = selected[0]
 
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+
+    print("Planning phase: checking downloads...", file=log)
+    run_planning_phase(
+        exo,
+        full_model_id,
+        preview,
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
+
     instance = preview["instance"]
     instance_id = instance_id_from_instance(instance)
     sharding = str(preview["sharding"])
@@ -35,6 +35,7 @@ from harness import (
     instance_id_from_instance,
     nodes_used_in_instance,
     resolve_model_short_id,
+    run_planning_phase,
    settle_and_fetch_placements,
     wait_for_instance_gone,
     wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
     if args.dry_run:
         return 0
 
+    settle_deadline = (
+        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+    )
+
+    logger.info("Planning phase: checking downloads...")
+    run_planning_phase(
+        client,
+        full_model_id,
+        selected[0],
+        args.danger_delete_downloads,
+        args.timeout,
+        settle_deadline,
+    )
+
     all_rows: list[dict[str, Any]] = []
 
     for preview in selected:
bench/harness.py (+150)

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
     return selected
 
 
+def run_planning_phase(
+    client: ExoClient,
+    full_model_id: str,
+    preview: dict[str, Any],
+    danger_delete: bool,
+    timeout: float,
+    settle_deadline: float | None,
+) -> None:
+    """Check disk space and ensure model is downloaded before benchmarking."""
+    # Get model size from /models
+    models = client.request_json("GET", "/models") or {}
+    model_bytes = 0
+    for m in models.get("data", []):
+        if m.get("hugging_face_id") == full_model_id:
+            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
+            break
+
+    if not model_bytes:
+        logger.warning(
+            f"Could not determine size for {full_model_id}, skipping disk check"
+        )
+        return
+
+    # Get nodes from preview
+    inner = unwrap_instance(preview["instance"])
+    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
+    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
+
+    state = client.request_json("GET", "/state")
+    downloads = state.get("downloads", {})
+    node_disk = state.get("nodeDisk", {})
+
+    for node_id in node_ids:
+        node_downloads = downloads.get(node_id, [])
+
+        # Check if model already downloaded on this node
+        already_downloaded = any(
+            "DownloadCompleted" in p
+            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                "modelId"
+            ]
+            == full_model_id
+            for p in node_downloads
+        )
+        if already_downloaded:
+            continue
+
+        # Wait for disk info if settle_deadline is set
+        disk_info = node_disk.get(node_id, {})
+        backoff = _SETTLE_INITIAL_BACKOFF_S
+        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
+            remaining = settle_deadline - time.monotonic()
+            logger.info(
+                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
+            )
+            time.sleep(min(backoff, remaining))
+            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
+            state = client.request_json("GET", "/state")
+            node_disk = state.get("nodeDisk", {})
+            disk_info = node_disk.get(node_id, {})
+
+        if not disk_info:
+            logger.warning(f"No disk info for {node_id}, skipping space check")
+            continue
+
+        avail = disk_info.get("available", {}).get("inBytes", 0)
+        if avail >= model_bytes:
+            continue
+
+        if not danger_delete:
+            raise RuntimeError(
+                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
+                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
+            )
+
+        # Delete from smallest to largest
+        completed = [
+            (
+                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ],
+                p["DownloadCompleted"]["totalBytes"]["inBytes"],
+            )
+            for p in node_downloads
+            if "DownloadCompleted" in p
+        ]
+        for del_model, size in sorted(completed, key=lambda x: x[1]):
+            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
+            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
+            avail += size
+            if avail >= model_bytes:
+                break
+
+        if avail < model_bytes:
+            raise RuntimeError(f"Could not free enough space on {node_id}")
+
+    # Start downloads (idempotent)
+    for node_id in node_ids:
+        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
+        shard = runner_to_shard[runner_id]
+        client.request_json(
+            "POST",
+            "/download/start",
+            body={
+                "targetNodeId": node_id,
+                "shardMetadata": shard,
+            },
+        )
+        logger.info(f"Started download on {node_id}")
+
+    # Wait for downloads
+    start = time.time()
+    while time.time() - start < timeout:
+        state = client.request_json("GET", "/state")
+        downloads = state.get("downloads", {})
+        all_done = True
+        for node_id in node_ids:
+            done = any(
+                "DownloadCompleted" in p
+                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
+                    "modelCard"
+                ]["modelId"]
+                == full_model_id
+                for p in downloads.get(node_id, [])
+            )
+            failed = [
+                p["DownloadFailed"]["errorMessage"]
+                for p in downloads.get(node_id, [])
+                if "DownloadFailed" in p
+                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
+                    "modelId"
+                ]
+                == full_model_id
+            ]
+            if failed:
+                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
+            if not done:
+                all_done = False
+        if all_done:
+            return
+        time.sleep(1)
+
+    raise TimeoutError("Downloads did not complete in time")
+
+
 def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
     ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
     ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
         default=0,
         help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
     )
+    ap.add_argument(
+        "--danger-delete-downloads",
+        action="store_true",
+        help="Delete existing models from smallest to largest to make room for benchmark model.",
+    )
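The space-freeing policy inside `run_planning_phase` is easiest to see in isolation: completed downloads are deleted smallest-first until the benchmark model fits, and the run aborts if even deleting everything would not be enough. A standalone Python sketch of that policy (no exo dependencies; the function name and tuple shape are illustrative, not part of the repo):

```python
def plan_deletions(
    completed: list[tuple[str, int]],  # (model_id, size_in_bytes) of completed downloads
    avail_bytes: int,
    needed_bytes: int,
) -> list[str]:
    """Pick downloads to delete, smallest first, until the new model fits.

    Sketch of the --danger-delete-downloads behaviour used by run_planning_phase
    above; raises if deleting every completed download still is not enough.
    """
    to_delete: list[str] = []
    for model_id, size in sorted(completed, key=lambda x: x[1]):
        if avail_bytes >= needed_bytes:
            break
        to_delete.append(model_id)
        avail_bytes += size
    if avail_bytes < needed_bytes:
        raise RuntimeError("could not free enough space")
    return to_delete
```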
@@ -14,6 +14,21 @@
       : 0,
   );
 
+  const etaText = $derived.by(() => {
+    if (progress.processed <= 0 || progress.total <= 0) return null;
+    const elapsedMs = performance.now() - progress.startedAt;
+    if (elapsedMs < 200) return null; // need a minimum sample window
+    const tokensPerMs = progress.processed / elapsedMs;
+    const remainingTokens = progress.total - progress.processed;
+    const remainingMs = remainingTokens / tokensPerMs;
+    const remainingSec = Math.ceil(remainingMs / 1000);
+    if (remainingSec <= 0) return null;
+    if (remainingSec < 60) return `~${remainingSec}s remaining`;
+    const mins = Math.floor(remainingSec / 60);
+    const secs = remainingSec % 60;
+    return `~${mins}m ${secs}s remaining`;
+  });
+
   function formatTokenCount(count: number | undefined): string {
     if (count == null) return "0";
     if (count >= 1000) {
@@ -40,8 +55,11 @@
         style="width: {percentage}%"
       ></div>
     </div>
-    <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
-      {percentage}%
+    <div
+      class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
+    >
+      <span>{etaText ?? ""}</span>
+      <span>{percentage}%</span>
     </div>
   </div>
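The `etaText` derivation above is plain rate arithmetic: tokens processed per millisecond so far, applied to the tokens that remain, with a small minimum sample window so the first chunk does not produce a wild estimate. The same arithmetic as a standalone sketch (names are illustrative):

```python
import math


def prefill_eta_text(processed: int, total: int, elapsed_ms: float, min_window_ms: float = 200) -> str | None:
    """Mirror of the etaText derivation above, as plain Python."""
    if processed <= 0 or total <= 0 or elapsed_ms < min_window_ms:
        return None
    tokens_per_ms = processed / elapsed_ms
    remaining_sec = math.ceil((total - processed) / tokens_per_ms / 1000)
    if remaining_sec <= 0:
        return None
    if remaining_sec < 60:
        return f"~{remaining_sec}s remaining"
    return f"~{remaining_sec // 60}m {remaining_sec % 60}s remaining"
```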
@@ -250,6 +250,11 @@ interface RawStateResponse {
   >;
   // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
   thunderboltBridgeCycles?: string[][];
+  // Disk usage per node
+  nodeDisk?: Record<
+    string,
+    { total: { inBytes: number }; available: { inBytes: number } }
+  >;
 }
 
 export interface MessageAttachment {
@@ -276,6 +281,8 @@ export interface TokenData {
 export interface PrefillProgress {
   processed: number;
   total: number;
+  /** Timestamp (performance.now()) when prefill started. */
+  startedAt: number;
 }
 
 export interface Message {
@@ -1652,11 +1659,12 @@ class AppStore {
     if (!reader) throw new Error("No response body");
 
     let fullContent = prefixText;
+    let streamedThinking = "";
     const collectedTokens: TokenData[] = [...tokensToKeep];
 
     interface ChatCompletionChunk {
       choices?: Array<{
-        delta?: { content?: string };
+        delta?: { content?: string; reasoning_content?: string };
         logprobs?: {
           content?: Array<{
             token: string;
@@ -1677,6 +1685,7 @@ class AppStore {
       (parsed) => {
         const choice = parsed.choices?.[0];
         const delta = choice?.delta?.content;
+        const thinkingDelta = choice?.delta?.reasoning_content;
 
         // Collect logprobs data
         const logprobsContent = choice?.logprobs?.content;
@@ -1695,7 +1704,11 @@ class AppStore {
           }
         }
 
-        if (delta) {
+        if (thinkingDelta) {
+          streamedThinking += thinkingDelta;
+        }
+
+        if (delta || thinkingDelta) {
           if (firstTokenTime === null) {
             firstTokenTime = performance.now();
             this.ttftMs = firstTokenTime - requestStartTime;
@@ -1709,9 +1722,14 @@ class AppStore {
             this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
           }
 
-          fullContent += delta;
-          const { displayContent, thinkingContent } =
+          if (delta) {
+            fullContent += delta;
+          }
+          const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(fullContent);
+          const combinedThinking = [streamedThinking, tagThinking]
+            .filter(Boolean)
+            .join("\n\n");
 
           if (this.activeConversationId === targetConversationId) {
             this.currentResponse = displayContent;
@@ -1723,7 +1741,7 @@ class AppStore {
             messageId,
             (m) => {
               m.content = displayContent;
-              m.thinking = thinkingContent || undefined;
+              m.thinking = combinedThinking || undefined;
               m.tokens = [...collectedTokens];
             },
           );
@@ -1735,11 +1753,14 @@ class AppStore {
 
     // Final update
     if (this.conversationExists(targetConversationId)) {
-      const { displayContent, thinkingContent } =
+      const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(fullContent);
+      const finalThinking = [streamedThinking, tagThinking]
+        .filter(Boolean)
+        .join("\n\n");
       this.updateConversationMessage(targetConversationId, messageId, (m) => {
         m.content = displayContent;
-        m.thinking = thinkingContent || undefined;
+        m.thinking = finalThinking || undefined;
         m.tokens = [...collectedTokens];
         if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
         if (this.tps !== null) m.tps = this.tps;
@@ -1847,11 +1868,12 @@ class AppStore {
     }
 
     let streamedContent = "";
+    let streamedThinking = "";
     const collectedTokens: TokenData[] = [];
 
     interface ChatCompletionChunk {
       choices?: Array<{
-        delta?: { content?: string };
+        delta?: { content?: string; reasoning_content?: string };
         logprobs?: {
           content?: Array<{
             token: string;
@@ -1872,6 +1894,7 @@ class AppStore {
       (parsed) => {
         const choice = parsed.choices?.[0];
         const delta = choice?.delta?.content;
+        const thinkingDelta = choice?.delta?.reasoning_content;
 
         // Collect logprobs data
         const logprobsContent = choice?.logprobs?.content;
@@ -1890,10 +1913,19 @@ class AppStore {
           }
         }
 
-        if (delta) {
-          streamedContent += delta;
-          const { displayContent, thinkingContent } =
+        if (thinkingDelta) {
+          streamedThinking += thinkingDelta;
+        }
+
+        if (delta || thinkingDelta) {
+          if (delta) {
+            streamedContent += delta;
+          }
+          const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(streamedContent);
+          const combinedThinking = [streamedThinking, tagThinking]
+            .filter(Boolean)
+            .join("\n\n");
 
           // Only update currentResponse if target conversation is active
           if (this.activeConversationId === targetConversationId) {
@@ -1906,7 +1938,7 @@ class AppStore {
               assistantMessage.id,
               (msg) => {
                 msg.content = displayContent;
-                msg.thinking = thinkingContent || undefined;
+                msg.thinking = combinedThinking || undefined;
                 msg.tokens = [...collectedTokens];
               },
             );
@@ -1918,14 +1950,17 @@ class AppStore {
 
     // Final cleanup of the message (if conversation still exists)
     if (this.conversationExists(targetConversationId)) {
-      const { displayContent, thinkingContent } =
+      const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(streamedContent);
+      const finalThinking = [streamedThinking, tagThinking]
+        .filter(Boolean)
+        .join("\n\n");
       this.updateConversationMessage(
         targetConversationId,
         assistantMessage.id,
         (msg) => {
           msg.content = displayContent;
-          msg.thinking = thinkingContent || undefined;
+          msg.thinking = finalThinking || undefined;
           msg.tokens = [...collectedTokens];
         },
       );
@@ -2317,10 +2352,11 @@ class AppStore {
     }
 
     let streamedContent = "";
+    let streamedThinking = "";
 
     interface ChatCompletionChunk {
       choices?: Array<{
-        delta?: { content?: string };
+        delta?: { content?: string; reasoning_content?: string };
         logprobs?: {
           content?: Array<{
             token: string;
@@ -2348,6 +2384,7 @@ class AppStore {
 
         const choice = parsed.choices?.[0];
         const tokenContent = choice?.delta?.content;
+        const thinkingContent = choice?.delta?.reasoning_content;
 
         // Collect logprobs data
         const logprobsContent = choice?.logprobs?.content;
@@ -2366,7 +2403,11 @@ class AppStore {
           }
         }
 
-        if (tokenContent) {
+        if (thinkingContent) {
+          streamedThinking += thinkingContent;
+        }
+
+        if (tokenContent || thinkingContent) {
           // Track first token for TTFT
           if (firstTokenTime === null) {
             firstTokenTime = performance.now();
@@ -2383,11 +2424,16 @@ class AppStore {
            this.tps = (tokenCount / elapsed) * 1000;
           }
 
-          streamedContent += tokenContent;
+          if (tokenContent) {
+            streamedContent += tokenContent;
+          }
 
-          // Strip thinking tags for display and extract thinking content
-          const { displayContent, thinkingContent } =
+          // Use stripThinkingTags as fallback for any <think> tags still in content
+          const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(streamedContent);
+          const combinedThinking = [streamedThinking, tagThinking]
+            .filter(Boolean)
+            .join("\n\n");
 
           // Only update currentResponse if target conversation is active
           if (this.activeConversationId === targetConversationId) {
@@ -2400,7 +2446,7 @@ class AppStore {
               assistantMessage.id,
               (msg) => {
                 msg.content = displayContent;
-                msg.thinking = thinkingContent || undefined;
+                msg.thinking = combinedThinking || undefined;
                 msg.tokens = [...collectedTokens];
               },
             );
@@ -2420,6 +2466,7 @@ class AppStore {
           this.prefillProgress = {
             processed: inner.processed_tokens,
             total: inner.total_tokens,
+            startedAt: this.prefillProgress?.startedAt ?? performance.now(),
           };
         },
       },
@@ -2436,14 +2483,17 @@ class AppStore {
 
     // Final cleanup of the message (if conversation still exists)
     if (this.conversationExists(targetConversationId)) {
-      const { displayContent, thinkingContent } =
+      const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(streamedContent);
+      const finalThinking = [streamedThinking, tagThinking]
+        .filter(Boolean)
+        .join("\n\n");
       this.updateConversationMessage(
         targetConversationId,
         assistantMessage.id,
         (msg) => {
           msg.content = displayContent;
-          msg.thinking = thinkingContent || undefined;
+          msg.thinking = finalThinking || undefined;
           msg.tokens = [...collectedTokens];
           // Store performance metrics on the message
           if (this.ttftMs !== null) {
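All three streaming paths above merge thinking text the same way: `reasoning_content` deltas accumulated from the stream come first, and anything recovered from `<think>` tags in the visible content is appended as a fallback, with empty pieces dropped. A standalone sketch of that merge:

```python
def combine_thinking(streamed_thinking: str, tag_thinking: str) -> str:
    """Mirror of the combinedThinking/finalThinking logic above."""
    return "\n\n".join(part for part in (streamed_thinking, tag_thinking) if part)


assert combine_thinking("streamed", "") == "streamed"
assert combine_thinking("streamed", "tagged") == "streamed\n\ntagged"
```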
@@ -114,6 +114,74 @@
   });
   let tb5InfoDismissed = $state(false);
 
+  // Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
+  const macStudioEn2RdmaWarning = $derived.by(() => {
+    const edges = data?.edges;
+    const ids = tbIdentifiers;
+    const rdmaCtl = rdmaCtlData;
+    if (!edges || !ids || !rdmaCtl) return null;
+
+    const affectedConnections: Array<{
+      nodeId: string;
+      nodeName: string;
+      peerNodeId: string;
+      peerNodeName: string;
+      rdmaIface: string;
+    }> = [];
+
+    const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
+      node?.system_info?.model_id === "Mac Studio";
+
+    for (const edge of edges) {
+      if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
+
+      const sourceNode = data?.nodes?.[edge.source];
+      if (
+        isMacStudio(sourceNode) &&
+        edge.sourceRdmaIface === "rdma_en2" &&
+        rdmaCtl[edge.source]?.enabled
+      ) {
+        affectedConnections.push({
+          nodeId: edge.source,
+          nodeName:
+            sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
+          peerNodeId: edge.target,
+          peerNodeName:
+            data?.nodes?.[edge.target]?.friendly_name ||
+            edge.target.slice(0, 8) + "...",
+          rdmaIface: "en2",
+        });
+      }
+
+      const sinkNode = data?.nodes?.[edge.target];
+      if (
+        isMacStudio(sinkNode) &&
+        edge.sinkRdmaIface === "rdma_en2" &&
+        rdmaCtl[edge.target]?.enabled
+      ) {
+        affectedConnections.push({
+          nodeId: edge.target,
+          nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
+          peerNodeId: edge.source,
+          peerNodeName:
+            sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
+          rdmaIface: "en2",
+        });
+      }
+    }
+
+    // Deduplicate by nodeId
+    const seen = new Set<string>();
+    const unique = affectedConnections.filter((c) => {
+      if (seen.has(c.nodeId)) return false;
+      seen.add(c.nodeId);
+      return true;
+    });
+
+    return unique.length > 0 ? unique : null;
+  });
+  let macStudioEn2Dismissed = $state(false);
+
   // Helper to get friendly node name from node ID
   function getNodeName(nodeId: string): string {
     const node = data?.nodes?.[nodeId];
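The derived check above boils down to a simple predicate: a node is flagged when it is a Mac Studio, its side of a Thunderbolt edge reports the RDMA interface `rdma_en2`, and RDMA is enabled for that node. A rough Python rendering of the same predicate (field names come from the Svelte code above; the exact data shapes are assumptions):

```python
def mac_studio_en2_rdma_nodes(edges, nodes, rdma_ctl):
    """Return node ids that use RDMA over en2 on a Mac Studio (rough sketch)."""
    affected = []
    for edge in edges:
        ends = (("source", "sourceRdmaIface"), ("target", "sinkRdmaIface"))
        for end, iface_key in ends:
            node_id = edge[end]
            node = nodes.get(node_id, {})
            if (
                node.get("system_info", {}).get("model_id") == "Mac Studio"
                and edge.get(iface_key) == "rdma_en2"
                and rdma_ctl.get(node_id, {}).get("enabled")
            ):
                affected.append(node_id)
    return sorted(set(affected))  # deduplicate, like the Set-based filter above
```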
@@ -790,10 +858,8 @@
     if (!progress || typeof progress !== "object") return null;
 
     const prog = progress as Record<string, unknown>;
-    const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
-    const downloadedBytes = getBytes(
-      prog.downloaded_bytes ?? prog.downloadedBytes,
-    );
+    const totalBytes = getBytes(prog.total);
+    const downloadedBytes = getBytes(prog.downloaded);
     const speed = (prog.speed as number) ?? 0;
     const completedFiles =
       (prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -806,8 +872,8 @@
     for (const [fileName, fileData] of Object.entries(filesObj)) {
       if (!fileData || typeof fileData !== "object") continue;
       const fd = fileData as Record<string, unknown>;
-      const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
-      const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
+      const fTotal = getBytes(fd.total);
+      const fDownloaded = getBytes(fd.downloaded);
       files.push({
         name: fileName,
         totalBytes: fTotal,
@@ -1196,7 +1262,6 @@
     if (typeof value === "number") return value;
     if (value && typeof value === "object") {
       const v = value as Record<string, unknown>;
-      if (typeof v.in_bytes === "number") return v.in_bytes;
       if (typeof v.inBytes === "number") return v.inBytes;
     }
     return 0;
@@ -1758,7 +1823,7 @@
 </script>
 
 {#snippet clusterWarnings()}
-  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
+  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
     <div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
       {#if tbBridgeCycles.length > 0}
         {@const cycle = tbBridgeCycles[0]}
@@ -1923,12 +1988,260 @@
         </button>
       </div>
     {/if}
+
+    {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
+      <div class="group relative" role="alert">
+        <div
+          class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
+        >
+          <svg
+            class="w-5 h-5 text-red-400 flex-shrink-0"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d={warningIconPath}
+            />
+          </svg>
+          <span class="text-sm font-mono text-red-200">
+            RDMA INCOMPATIBLE PORT
+          </span>
+          <button
+            type="button"
+            onclick={() => (macStudioEn2Dismissed = true)}
+            class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
+            title="Dismiss"
+          >
+            <svg
+              class="w-4 h-4"
+              fill="none"
+              viewBox="0 0 24 24"
+              stroke="currentColor"
+              stroke-width="2"
+            >
+              <path
+                stroke-linecap="round"
+                stroke-linejoin="round"
+                d="M6 18L18 6M6 6l12 12"
+              />
+            </svg>
+          </button>
+        </div>
+
+        <!-- Expanded tooltip on hover -->
+        <div
+          class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
+        >
+          <p class="text-xs text-white/80 mb-3">
+            The Thunderbolt 5 port next to the Ethernet port on Mac Studio
+            does
+            <span class="text-red-400 font-semibold">not support RDMA</span>.
+            Move the cable to one of the other three TB5 ports.
+          </p>
+
+          <div class="text-xs text-white/60 mb-3">
+            <span class="text-red-300">Affected:</span>
+            {#each macStudioEn2RdmaWarning as conn}
+              <div class="ml-2 mt-0.5">
+                <span class="text-white/80">{conn.nodeName}</span>
+                <span class="text-white/30">→</span>
+                <span class="text-white/60">{conn.peerNodeName}</span>
+                <span class="text-white/30 ml-1">(en2)</span>
+              </div>
+            {/each}
+          </div>
+
+          <!-- Mac Studio back panel illustration -->
+          <div class="bg-black/40 rounded p-3 mb-3">
+            <p
+              class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
+            >
+              Mac Studio — Rear Panel
+            </p>
+            <svg viewBox="0 0 320 72" class="w-full" xmlns="http://www.w3.org/2000/svg">
+              <rect x="1" y="1" width="318" height="70" rx="6" ry="6" fill="none" stroke="rgba(255,255,255,0.12)" stroke-width="1" />
+              <!-- TB5 port 1 -->
+              <rect x="24" y="22" width="28" height="14" rx="4" fill="none" stroke="rgba(255,255,255,0.3)" stroke-width="1" />
+              <text x="38" y="52" text-anchor="middle" fill="rgba(255,255,255,0.25)" style="font-size:7px;font-family:ui-monospace,monospace;">TB5</text>
+              <!-- TB5 port 2 -->
+              <rect x="62" y="22" width="28" height="14" rx="4" fill="none" stroke="rgba(255,255,255,0.3)" stroke-width="1" />
+              <text x="76" y="52" text-anchor="middle" fill="rgba(255,255,255,0.25)" style="font-size:7px;font-family:ui-monospace,monospace;">TB5</text>
+              <!-- TB5 port 3 -->
+              <rect x="100" y="22" width="28" height="14" rx="4" fill="none" stroke="rgba(255,255,255,0.3)" stroke-width="1" />
+              <text x="114" y="52" text-anchor="middle" fill="rgba(255,255,255,0.25)" style="font-size:7px;font-family:ui-monospace,monospace;">TB5</text>
+              <!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
+              <rect x="138" y="22" width="28" height="14" rx="4" fill="rgba(239,68,68,0.1)" stroke="rgba(239,68,68,0.7)" stroke-width="1.5" />
+              <line x1="142" y1="25" x2="162" y2="33" stroke="rgba(239,68,68,0.8)" stroke-width="1.5" stroke-linecap="round" />
+              <line x1="162" y1="25" x2="142" y2="33" stroke="rgba(239,68,68,0.8)" stroke-width="1.5" stroke-linecap="round" />
+              <text x="152" y="52" text-anchor="middle" fill="rgba(239,68,68,0.6)" style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;">en2</text>
+              <!-- Ethernet port -->
+              <rect x="196" y="19" width="24" height="20" rx="2" fill="none" stroke="rgba(255,255,255,0.2)" stroke-width="1" />
+              <rect x="200" y="23" width="16" height="12" rx="1" fill="none" stroke="rgba(255,255,255,0.12)" stroke-width="0.75" />
+              <text x="208" y="52" text-anchor="middle" fill="rgba(255,255,255,0.25)" style="font-size:7px;font-family:ui-monospace,monospace;">ETH</text>
+              <!-- Green checkmarks on working ports -->
+              <circle cx="38" cy="62" r="3" fill="none" stroke="rgba(74,222,128,0.5)" stroke-width="0.75" />
+              <circle cx="76" cy="62" r="3" fill="none" stroke="rgba(74,222,128,0.5)" stroke-width="0.75" />
+              <circle cx="114" cy="62" r="3" fill="none" stroke="rgba(74,222,128,0.5)" stroke-width="0.75" />
+            </svg>
+          </div>
+
+          <p class="text-xs text-white/50">
+            <span class="text-green-400">Fix:</span> Move the Thunderbolt cable
+            to any of the three leftmost ports (all support RDMA).
+          </p>
+        </div>
+      </div>
+    {/if}
   </div>
 {/if}
 {/snippet}
 
 {#snippet clusterWarningsCompact()}
-  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
+  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
     <div class="absolute top-2 left-2 flex flex-col gap-1">
       {#if tbBridgeCycles.length > 0}
         <div
@@ -1996,6 +2309,27 @@
           >
         </div>
       {/if}
+      {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
+        <div
+          class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
+          title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
+        >
+          <svg
+            class="w-3.5 h-3.5 text-red-400"
+            fill="none"
+            viewBox="0 0 24 24"
+            stroke="currentColor"
+            stroke-width="2"
+          >
+            <path
+              stroke-linecap="round"
+              stroke-linejoin="round"
+              d={warningIconPath}
+            />
+          </svg>
+          <span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
+        </div>
+      {/if}
     </div>
   {/if}
 {/snippet}
@@ -74,7 +74,6 @@
     if (typeof value === "number") return value;
     if (value && typeof value === "object") {
       const v = value as Record<string, unknown>;
-      if (typeof v.in_bytes === "number") return v.in_bytes;
       if (typeof v.inBytes === "number") return v.inBytes;
     }
     return 0;
@@ -231,23 +230,14 @@
             undefined;
         let cell: CellStatus;
         if (tag === "DownloadCompleted") {
-          const totalBytes = getBytes(
-            payload.total_bytes ?? payload.totalBytes,
-          );
+          const totalBytes = getBytes(payload.total);
           cell = { kind: "completed", totalBytes, modelDirectory };
         } else if (tag === "DownloadOngoing") {
           const rawProgress =
             payload.download_progress ?? payload.downloadProgress ?? {};
           const prog = rawProgress as Record<string, unknown>;
-          const totalBytes = getBytes(
-            prog.total_bytes ??
-              prog.totalBytes ??
-              payload.total_bytes ??
-              payload.totalBytes,
-          );
-          const downloadedBytes = getBytes(
-            prog.downloaded_bytes ?? prog.downloadedBytes,
-          );
+          const totalBytes = getBytes(prog.total ?? payload.total);
+          const downloadedBytes = getBytes(prog.downloaded);
           const speed = (prog.speed as number) ?? 0;
           const etaMs =
             (prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;
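After the getBytes hunks above, both dashboards read byte counts only from a bare number or the camelCase `{ inBytes }` shape; the snake_case `in_bytes` and `*_bytes` fallbacks are gone. The equivalent normalisation written out in Python, for reference (a sketch, not code from the repo):

```python
def get_bytes(value) -> float:
    """Accept a bare number or an {"inBytes": n} mapping; anything else is 0."""
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, dict):
        in_bytes = value.get("inBytes")
        if isinstance(in_bytes, (int, float)):
            return in_bytes
    return 0
```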
@@ -19,7 +19,7 @@ class ConnectionUpdate:
         Whether this is a connection or disconnection event
         """
     @property
-    def peer_id(self) -> PeerId:
+    def peer_id(self) -> builtins.str:
         r"""
         Identity of the peer that we have connected to or disconnected from.
         """
@@ -40,92 +40,22 @@ class Keypair:
     Identity keypair of a node.
     """
     @staticmethod
-    def generate_ed25519() -> Keypair:
+    def generate() -> Keypair:
         r"""
         Generate a new Ed25519 keypair.
         """
     @staticmethod
-    def generate_ecdsa() -> Keypair:
+    def from_bytes(bytes: bytes) -> Keypair:
         r"""
-        Generate a new ECDSA keypair.
-        """
-    @staticmethod
-    def generate_secp256k1() -> Keypair:
-        r"""
-        Generate a new Secp256k1 keypair.
-        """
-    @staticmethod
-    def from_protobuf_encoding(bytes: bytes) -> Keypair:
-        r"""
-        Decode a private key from a protobuf structure and parse it as a `Keypair`.
-        """
-    @staticmethod
-    def rsa_from_pkcs8(bytes: bytes) -> Keypair:
-        r"""
-        Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
-        format (i.e. unencrypted) as defined in [RFC5208].
-
-        [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
-        """
-    @staticmethod
-    def secp256k1_from_der(bytes: bytes) -> Keypair:
-        r"""
-        Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
-        structure as defined in [RFC5915].
-
-        [RFC5915]: https://tools.ietf.org/html/rfc5915
-        """
-    @staticmethod
-    def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
-    def to_protobuf_encoding(self) -> bytes:
-        r"""
-        Encode a private key as protobuf structure.
-        """
-    def to_peer_id(self) -> PeerId:
-        r"""
-        Convert the `Keypair` into the corresponding `PeerId`.
-        """
-
-@typing.final
-class Multiaddr:
-    r"""
-    Representation of a Multiaddr.
-    """
-    @staticmethod
-    def empty() -> Multiaddr:
-        r"""
-        Create a new, empty multiaddress.
-        """
-    @staticmethod
-    def with_capacity(n: builtins.int) -> Multiaddr:
-        r"""
-        Create a new, empty multiaddress with the given capacity.
-        """
-    @staticmethod
-    def from_bytes(bytes: bytes) -> Multiaddr:
-        r"""
-        Parse a `Multiaddr` value from its byte slice representation.
-        """
-    @staticmethod
-    def from_string(string: builtins.str) -> Multiaddr:
-        r"""
-        Parse a `Multiaddr` value from its string representation.
-        """
-    def len(self) -> builtins.int:
-        r"""
-        Return the length in bytes of this multiaddress.
-        """
-    def is_empty(self) -> builtins.bool:
-        r"""
-        Returns true if the length of this multiaddress is 0.
+        Construct an Ed25519 keypair from secret key bytes
         """
     def to_bytes(self) -> bytes:
         r"""
-        Return a copy of this [`Multiaddr`]'s byte representation.
+        Get the secret key bytes underlying the keypair
         """
-    def to_string(self) -> builtins.str:
+    def to_node_id(self) -> builtins.str:
         r"""
-        Convert a Multiaddr to a string.
+        Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
         """
 
 @typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
     def __repr__(self) -> builtins.str: ...
     def __str__(self) -> builtins.str: ...
 
-@typing.final
-class PeerId:
-    r"""
-    Identifier of a peer of the network.
-
-    The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
-    as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
-    """
-    @staticmethod
-    def random() -> PeerId:
-        r"""
-        Generates a random peer ID from a cryptographically secure PRNG.
-
-        This is useful for randomly walking on a DHT, or for testing purposes.
-        """
-    @staticmethod
-    def from_bytes(bytes: bytes) -> PeerId:
-        r"""
-        Parses a `PeerId` from bytes.
-        """
-    def to_bytes(self) -> bytes:
-        r"""
-        Returns a raw bytes representation of this `PeerId`.
-        """
-    def to_base58(self) -> builtins.str:
-        r"""
-        Returns a base-58 encoded string of this `PeerId`.
-        """
-    def __repr__(self) -> builtins.str: ...
-    def __str__(self) -> builtins.str: ...
-
 @typing.final
 class ConnectionUpdateType(enum.Enum):
     r"""
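The stub now exposes a much smaller surface: `Keypair.generate()`, `Keypair.from_bytes()`, `to_bytes()` and `to_node_id()`, with `PeerId` and `Multiaddr` gone entirely. A hedged usage sketch; the import path is a guess, since the extension module's name is not shown in this diff:

```python
from exo_networking import Keypair  # hypothetical module name

kp = Keypair.generate()                        # fresh Ed25519 keypair
node_id = kp.to_node_id()                      # base58 PeerId string, used as the NodeId
restored = Keypair.from_bytes(kp.to_bytes())   # round-trip via raw secret-key bytes
assert restored.to_node_id() == node_id
```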
@@ -1,8 +1,6 @@
 use crate::ext::ResultExt as _;
-use libp2p::PeerId;
 use libp2p::identity::Keypair;
-use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
-use pyo3::types::PyBytes;
+use pyo3::types::{PyBytes, PyBytesMethods as _};
 use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
 use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
 
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
 impl PyKeypair {
     /// Generate a new Ed25519 keypair.
     #[staticmethod]
-    fn generate_ed25519() -> Self {
+    fn generate() -> Self {
         Self(Keypair::generate_ed25519())
     }
 
-    /// Generate a new ECDSA keypair.
+    /// Construct an Ed25519 keypair from secret key bytes
     #[staticmethod]
-    fn generate_ecdsa() -> Self {
-        Self(Keypair::generate_ecdsa())
-    }
-
-    /// Generate a new Secp256k1 keypair.
-    #[staticmethod]
-    fn generate_secp256k1() -> Self {
-        Self(Keypair::generate_secp256k1())
-    }
-
-    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
-    #[staticmethod]
-    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
-    }
-
-    /// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
-    /// format (i.e. unencrypted) as defined in [RFC5208].
-    ///
-    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
-    #[staticmethod]
-    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let mut bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
-    }
-
-    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
-    /// structure as defined in [RFC5915].
-    ///
-    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
-    #[staticmethod]
-    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let mut bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
-    }
-
-    #[staticmethod]
-    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
+    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
         let mut bytes = Vec::from(bytes.as_bytes());
         Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
     }
 
-    /// Encode a private key as protobuf structure.
-    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
-        let bytes = self.0.to_protobuf_encoding().pyerr()?;
+    /// Get the secret key bytes underlying the keypair
+    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+        let bytes = self
+            .0
+            .clone()
+            .try_into_ed25519()
+            .pyerr()?
+            .secret()
+            .as_ref()
+            .to_vec();
         Ok(PyBytes::new(py, &bytes))
     }
 
-    /// Convert the `Keypair` into the corresponding `PeerId`.
-    fn to_peer_id(&self) -> PyPeerId {
-        PyPeerId(self.0.public().to_peer_id())
-    }
-
-    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
-    // #[gen_stub(skip)]
-    // #[new]
-    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-    //     Self::from_protobuf_encoding(bytes)
-    // }
-    //
-    // #[gen_stub(skip)]
-    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
-    //     *self = Self::from_protobuf_encoding(state)?;
-    //     Ok(())
-    // }
-    //
-    // #[gen_stub(skip)]
-    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
-    //     self.to_protobuf_encoding(py)
-    // }
-    //
-    // #[gen_stub(skip)]
-    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
-    //     Ok((self.to_protobuf_encoding(py)?,))
-    // }
-}
-
-/// Identifier of a peer of the network.
-///
-/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
-/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
-#[gen_stub_pyclass]
-#[pyclass(name = "PeerId", frozen)]
-#[derive(Debug, Clone)]
-#[repr(transparent)]
-pub struct PyPeerId(pub PeerId);
-
-#[gen_stub_pymethods]
-#[pymethods]
-#[allow(clippy::needless_pass_by_value)]
-impl PyPeerId {
-    /// Generates a random peer ID from a cryptographically secure PRNG.
-    ///
-    /// This is useful for randomly walking on a DHT, or for testing purposes.
-    #[staticmethod]
-    fn random() -> Self {
-        Self(PeerId::random())
-    }
-
-    /// Parses a `PeerId` from bytes.
-    #[staticmethod]
-    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
-        let bytes = Vec::from(bytes.as_bytes());
-        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
-    }
-
-    /// Returns a raw bytes representation of this `PeerId`.
-    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
-        let bytes = self.0.to_bytes();
-        PyBytes::new(py, &bytes)
-    }
-
-    /// Returns a base-58 encoded string of this `PeerId`.
-    fn to_base58(&self) -> String {
-        self.0.to_base58()
-    }
-
-    fn __repr__(&self) -> String {
-        format!("PeerId({})", self.to_base58())
-    }
-
-    fn __str__(&self) -> String {
-        self.to_base58()
+    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
+    fn to_node_id(&self) -> String {
+        self.0.public().to_peer_id().to_base58()
     }
 }
-
-pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<PyKeypair>()?;
-    m.add_class::<PyPeerId>()?;
-
-    Ok(())
-}
@@ -8,9 +8,10 @@ mod allow_threading;
 mod ident;
 mod networking;
 
-use crate::ident::ident_submodule;
+use crate::ident::PyKeypair;
 use crate::networking::networking_submodule;
 use pyo3::prelude::PyModule;
+use pyo3::types::PyModuleMethods;
 use pyo3::{Bound, PyResult, pyclass, pymodule};
 use pyo3_stub_gen::define_stub_info_gatherer;
 
@@ -158,7 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
     // work with maturin, where the types generate correctly, in the right folder, without
     // too many importing issues...
-    ident_submodule(m)?;
+    m.add_class::<PyKeypair>()?;
     networking_submodule(m)?;
 
     // top-level constructs
@@ -8,7 +8,7 @@
 use crate::r#const::MPSC_CHANNEL_SIZE;
 use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
 use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
-use crate::ident::{PyKeypair, PyPeerId};
+use crate::ident::PyKeypair;
 use crate::pyclass;
 use libp2p::futures::StreamExt as _;
 use libp2p::gossipsub;
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
 
     /// Identity of the peer that we have connected to or disconnected from.
     #[pyo3(get)]
-    peer_id: PyPeerId,
+    peer_id: String,
 
     /// Remote connection's IPv4 address.
     #[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
             // send connection event to channel (or exit if connection closed)
             if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                 update_type: PyConnectionUpdateType::Connected,
-                peer_id: PyPeerId(peer_id),
+                peer_id: peer_id.to_base58(),
                 remote_ipv4,
                 remote_tcp_port,
             }).await {
@@ -272,7 +272,7 @@ async fn networking_task(
             // send disconnection event to channel (or exit if connection closed)
             if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                 update_type: PyConnectionUpdateType::Disconnected,
-                peer_id: PyPeerId(peer_id),
+                peer_id: peer_id.to_base58(),
                 remote_ipv4,
                 remote_tcp_port,
             }).await {
@@ -80,7 +80,7 @@ class DownloadCoordinator:
            completed = DownloadCompleted(
                shard_metadata=callback_shard,
                node_id=self.node_id,
-                total_bytes=progress.total_bytes,
+                total=progress.total,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
            completed = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
-                total_bytes=initial_progress.total_bytes,
+                total=initial_progress.total,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
                status: DownloadProgress = DownloadCompleted(
                    node_id=self.node_id,
                    shard_metadata=progress.shard,
-                    total_bytes=progress.total_bytes,
+                    total=progress.total,
                    model_directory=self._model_dir(
                        progress.shard.model_card.model_id
                    ),
                )
            elif progress.status in ["in_progress", "not_started"]:
-                if progress.downloaded_bytes_this_session.in_bytes == 0:
+                if progress.downloaded_bytes.in_bytes == 0:
                    status = DownloadPending(
                        node_id=self.node_id,
                        shard_metadata=progress.shard,
@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
    repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
    return DownloadProgressData(
-        downloaded_bytes=repo_file_download_progress.downloaded,
-        downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
-        total_bytes=repo_file_download_progress.total,
+        downloaded=repo_file_download_progress.downloaded,
+        downloaded_this_session=repo_file_download_progress.downloaded_this_session,
+        total=repo_file_download_progress.total,
        completed_files=1 if repo_file_download_progress.status == "complete" else 0,
        total_files=1,
        speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
    repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
    return DownloadProgressData(
-        total_bytes=repo_download_progress.total_bytes,
-        downloaded_bytes=repo_download_progress.downloaded_bytes,
-        downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
+        total=repo_download_progress.total,
+        downloaded=repo_download_progress.downloaded,
+        downloaded_this_session=repo_download_progress.downloaded_this_session,
        completed_files=repo_download_progress.completed_files,
        total_files=repo_download_progress.total_files,
        speed=repo_download_progress.overall_speed,
@@ -142,7 +142,7 @@ async def delete_model(model_id: ModelId) -> bool:


async def seed_models(seed_dir: str | Path):
-    """Move model in resources folder of app to .cache/huggingface/hub"""
+    """Move models from resources folder to EXO_MODELS_DIR."""
    source_dir = Path(seed_dir)
    dest_dir = await ensure_models_dir()
    for path in source_dir.iterdir():
@@ -578,19 +578,20 @@ def calculate_repo_progress(
    file_progress: dict[str, RepoFileDownloadProgress],
    all_start_time: float,
) -> RepoDownloadProgress:
-    all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
-    all_downloaded_bytes = sum(
-        (p.downloaded.in_bytes for p in file_progress.values()), 0
+    all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
+    all_downloaded = sum(
+        (p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
    )
-    all_downloaded_bytes_this_session = sum(
-        (p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
+    all_downloaded_this_session = sum(
+        (p.downloaded_this_session for p in file_progress.values()),
+        Memory.from_bytes(0),
    )
    elapsed_time = time.time() - all_start_time
    all_speed = (
-        all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
+        all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
    )
    all_eta = (
-        timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
+        timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
        if all_speed > 0
        else timedelta(seconds=0)
    )
@@ -609,11 +610,9 @@ def calculate_repo_progress(
            [p for p in file_progress.values() if p.downloaded == p.total]
        ),
        total_files=len(file_progress),
-        downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
-        downloaded_bytes_this_session=Memory.from_bytes(
-            all_downloaded_bytes_this_session
-        ),
-        total_bytes=Memory.from_bytes(all_total_bytes),
+        downloaded=all_downloaded,
+        downloaded_this_session=all_downloaded_this_session,
+        total=all_total,
        overall_speed=all_speed,
        overall_eta=all_eta,
        status=status,
@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
    ),
    completed_files=0,
    total_files=0,
-    downloaded_bytes=Memory.from_bytes(0),
-    downloaded_bytes_this_session=Memory.from_bytes(0),
-    total_bytes=Memory.from_bytes(0),
+    downloaded=Memory.from_bytes(0),
+    downloaded_this_session=Memory.from_bytes(0),
+    total=Memory.from_bytes(0),
    overall_speed=0,
    overall_eta=timedelta(seconds=0),
    status="complete",
@@ -45,7 +45,7 @@ class Node:
    @classmethod
    async def create(cls, args: "Args") -> "Self":
        keypair = get_node_id_keypair()
-        node_id = NodeId(keypair.to_peer_id().to_base58())
+        node_id = NodeId(keypair.to_node_id())
        session_id = SessionId(master_node_id=node_id, election_clock=0)
        router = Router.create(keypair)
        await router.register_topic(topics.GLOBAL_EVENTS)
@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
            chat_template_messages.append({"role": "system", "content": content})
        else:
            # Skip messages with no meaningful content
-            if msg.content is None and msg.thinking is None and msg.tool_calls is None:
+            if (
+                msg.content is None
+                and msg.reasoning_content is None
+                and msg.tool_calls is None
+            ):
                continue

            if msg.role in ("user", "assistant", "developer"):
@@ -111,6 +115,11 @@ def chunk_to_response(
        ]
    )

+    if chunk.is_thinking:
+        delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
+    else:
+        delta = ChatCompletionMessage(role="assistant", content=chunk.text)
+
    return ChatCompletionResponse(
        id=command_id,
        created=int(time.time()),
@@ -118,7 +127,7 @@ def chunk_to_response(
        choices=[
            StreamingChoiceResponse(
                index=0,
-                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
+                delta=delta,
                logprobs=logprobs,
                finish_reason=chunk.finish_reason,
            )
@@ -208,6 +217,7 @@ async def collect_chat_response(
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ChatCompletionResponse."""
    text_parts: list[str] = []
+    thinking_parts: list[str] = []
    tool_calls: list[ToolCall] = []
    logprobs_content: list[LogprobsContentItem] = []
    model: str | None = None
@@ -228,7 +238,10 @@ async def collect_chat_response(
            if model is None:
                model = chunk.model
            last_usage = chunk.usage or last_usage
-            text_parts.append(chunk.text)
+            if chunk.is_thinking:
+                thinking_parts.append(chunk.text)
+            else:
+                text_parts.append(chunk.text)
            if chunk.logprob is not None:
                logprobs_content.append(
                    LogprobsContentItem(
@@ -258,6 +271,7 @@ async def collect_chat_response(
        raise ValueError(error_message)

    combined_text = "".join(text_parts)
+    combined_thinking = "".join(thinking_parts) if thinking_parts else None
    assert model is not None

    yield ChatCompletionResponse(
@@ -270,6 +284,7 @@ async def collect_chat_response(
        message=ChatCompletionMessage(
            role="assistant",
            content=combined_text,
+            reasoning_content=combined_thinking,
            tool_calls=tool_calls if tool_calls else None,
        ),
        logprobs=Logprobs(content=logprobs_content)
@@ -1,6 +1,7 @@
"""Claude Messages API adapter for converting requests/responses."""

import json
+import re
from collections.abc import AsyncGenerator
from typing import Any

@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
    ClaudeStopReason,
    ClaudeTextBlock,
    ClaudeTextDelta,
+    ClaudeThinkingBlock,
+    ClaudeThinkingDelta,
    ClaudeToolResultBlock,
    ClaudeToolUseBlock,
    ClaudeUsage,
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
    return "".join(sub_block.text for sub_block in block.content)


+# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
+# or similar telemetry headers that change every request and break KV prefix caching.
+_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)
+
+
+def _strip_volatile_headers(text: str) -> str:
+    """Remove Anthropic billing/telemetry headers from system prompt text.
+
+    Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
+    cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
+    change every request and break KV prefix caching (the prefix diverges at ~20
+    tokens instead of matching thousands of conversation tokens).
+    """
+    return _VOLATILE_HEADER_RE.sub("", text)
+
+
def claude_request_to_text_generation(
    request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
            instructions = request.system
        else:
            instructions = "".join(block.text for block in request.system)
+
+        instructions = _strip_volatile_headers(instructions)
        chat_template_messages.append({"role": "system", "content": instructions})

    # Convert messages to input
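For illustration only (not part of the change itself), a minimal sketch of how the volatile-header stripping above behaves; the sample header line is invented for the example:

    import re

    # Same pattern as _VOLATILE_HEADER_RE above.
    volatile_header_re = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)

    system_prompt = (
        "x-anthropic-billing-header: cc_version=2.0.1; cch=deadbeef;\n"  # hypothetical header
        "You are a coding assistant.\n"
    )

    # The header line is dropped, so the cached KV prefix stays stable across requests.
    print(volatile_header_re.sub("", system_prompt))
    # -> "You are a coding assistant.\n"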
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(

        # Process structured content blocks
        text_parts: list[str] = []
+        thinking_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        tool_results: list[ClaudeToolResultBlock] = []

        for block in msg.content:
            if isinstance(block, ClaudeTextBlock):
                text_parts.append(block.text)
+            elif isinstance(block, ClaudeThinkingBlock):
+                thinking_parts.append(block.thinking)
            elif isinstance(block, ClaudeToolUseBlock):
                tool_calls.append(
                    {
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
                tool_results.append(block)

        content = "".join(text_parts)
+        reasoning_content = "".join(thinking_parts) if thinking_parts else None

        # Build InputMessage from text content
        if msg.role in ("user", "assistant"):
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(

        # Build chat_template_messages preserving tool structure
        if tool_calls:
-            chat_template_messages.append(
-                {"role": "assistant", "content": content, "tool_calls": tool_calls}
-            )
+            chat_msg: dict[str, Any] = {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": tool_calls,
+            }
+            if reasoning_content:
+                chat_msg["reasoning_content"] = reasoning_content
+            chat_template_messages.append(chat_msg)
        elif tool_results:
            for tr in tool_results:
                chat_template_messages.append(
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
                    }
                )
        else:
-            chat_template_messages.append({"role": msg.role, "content": content})
+            chat_msg = {"role": msg.role, "content": content}
+            if reasoning_content:
+                chat_msg["reasoning_content"] = reasoning_content
+            chat_template_messages.append(chat_msg)

    # Convert Claude tool definitions to OpenAI-style function tools
    tools: list[dict[str, Any]] | None = None
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
            for tool in request.tools
        ]

+    enable_thinking: bool | None = None
+    if request.thinking is not None:
+        enable_thinking = request.thinking.type in ("enabled", "adaptive")
+
    return TextGenerationTaskParams(
        model=request.model,
        input=input_messages
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
        stop=request.stop_sequences,
        stream=request.stream,
        tools=tools,
+        enable_thinking=enable_thinking,
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
@@ -173,6 +211,7 @@ async def collect_claude_response(
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ClaudeMessagesResponse."""
    text_parts: list[str] = []
+    thinking_parts: list[str] = []
    tool_use_blocks: list[ClaudeToolUseBlock] = []
    stop_reason: ClaudeStopReason | None = None
    last_usage: Usage | None = None
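As a sketch only (the request and message shapes below are simplified illustrations, not taken verbatim from this diff), this is roughly how an assistant turn carrying a thinking block maps onto the dict-based chat_template_messages built above:

    # Hypothetical Claude-style assistant turn (content blocks simplified).
    assistant_turn_blocks = [
        {"type": "thinking", "thinking": "Check the edge cases first."},
        {"type": "text", "text": "Here is the plan."},
    ]

    # Approximate chat_template_messages entry produced by the conversion above:
    chat_msg = {
        "role": "assistant",
        "content": "Here is the plan.",
        "reasoning_content": "Check the edge cases first.",
    }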
@@ -200,7 +239,10 @@ async def collect_claude_response(
            stop_reason = "tool_use"
            continue

-        text_parts.append(chunk.text)
+        if chunk.is_thinking:
+            thinking_parts.append(chunk.text)
+        else:
+            text_parts.append(chunk.text)

        if chunk.finish_reason is not None:
            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -209,9 +251,12 @@ async def collect_claude_response(
        raise ValueError(error_message)

    combined_text = "".join(text_parts)
+    combined_thinking = "".join(thinking_parts)

    # Build content blocks
    content: list[ClaudeContentBlock] = []
+    if combined_thinking:
+        content.append(ClaudeThinkingBlock(thinking=combined_thinking))
    if combined_text:
        content.append(ClaudeTextBlock(text=combined_text))
    content.extend(tool_use_blocks)
@@ -256,16 +301,16 @@ async def generate_claude_stream(
    start_event = ClaudeMessageStartEvent(message=initial_message)
    yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"

-    # content_block_start for text block at index 0
-    block_start = ClaudeContentBlockStartEvent(
-        index=0, content_block=ClaudeTextBlock(text="")
-    )
-    yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"

    output_tokens = 0
    stop_reason: ClaudeStopReason | None = None
    last_usage: Usage | None = None
-    next_block_index = 1  # text block is 0, tool blocks start at 1
+    next_block_index = 0

+    # Track whether we've started thinking/text blocks
+    thinking_block_started = False
+    thinking_block_index = -1
+    text_block_started = False
+    text_block_index = -1
+
    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
@@ -310,12 +355,45 @@ async def generate_claude_stream(

        output_tokens += 1  # Count each chunk as one token

-        # content_block_delta
-        delta_event = ClaudeContentBlockDeltaEvent(
-            index=0,
-            delta=ClaudeTextDelta(text=chunk.text),
-        )
-        yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
+        if chunk.is_thinking:
+            # Start thinking block on first thinking token
+            if not thinking_block_started:
+                thinking_block_started = True
+                thinking_block_index = next_block_index
+                next_block_index += 1
+                block_start = ClaudeContentBlockStartEvent(
+                    index=thinking_block_index,
+                    content_block=ClaudeThinkingBlock(thinking=""),
+                )
+                yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
+
+            delta_event = ClaudeContentBlockDeltaEvent(
+                index=thinking_block_index,
+                delta=ClaudeThinkingDelta(thinking=chunk.text),
+            )
+            yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
+        else:
+            # Close thinking block when transitioning to text
+            if thinking_block_started and text_block_index == -1:
+                block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
+                yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+
+            # Start text block on first text token
+            if not text_block_started:
+                text_block_started = True
+                text_block_index = next_block_index
+                next_block_index += 1
+                block_start = ClaudeContentBlockStartEvent(
+                    index=text_block_index,
+                    content_block=ClaudeTextBlock(text=""),
+                )
+                yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
+
+            delta_event = ClaudeContentBlockDeltaEvent(
+                index=text_block_index,
+                delta=ClaudeTextDelta(text=chunk.text),
+            )
+            yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"

        if chunk.finish_reason is not None:
            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -324,9 +402,22 @@ async def generate_claude_stream(
    if last_usage is not None:
        output_tokens = last_usage.completion_tokens

-    # content_block_stop for text block
-    block_stop = ClaudeContentBlockStopEvent(index=0)
-    yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+    # Close any open blocks
+    if thinking_block_started and text_block_index == -1:
+        block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
+        yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+
+    if text_block_started:
+        block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
+        yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+
+    if not thinking_block_started and not text_block_started:
+        empty_start = ClaudeContentBlockStartEvent(
+            index=0, content_block=ClaudeTextBlock(text="")
+        )
+        yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
+        empty_stop = ClaudeContentBlockStopEvent(index=0)
+        yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"

    # message_delta
    message_delta = ClaudeMessageDeltaEvent(
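For orientation, an illustrative sequence (not emitted anywhere in this diff) of the SSE events that the dynamic block tracking above aims to produce when a reply contains thinking followed by text; the trailing message_stop is assumed from the Claude Messages streaming protocol rather than shown here:

    expected_events = [
        "message_start",
        "content_block_start",   # thinking block at index 0
        "content_block_delta",   # ClaudeThinkingDelta per thinking token
        "content_block_stop",    # thinking block closed on first text token
        "content_block_start",   # text block at index 1
        "content_block_delta",   # ClaudeTextDelta per text token
        "content_block_stop",
        "message_delta",
        "message_stop",
    ]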
src/exo/master/adapters/ollama.py (new file, 456 lines)
@@ -0,0 +1,456 @@
from __future__ import annotations

import json
from collections.abc import AsyncGenerator
from typing import Any

from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
    OllamaChatRequest,
    OllamaChatResponse,
    OllamaDoneReason,
    OllamaGenerateRequest,
    OllamaGenerateResponse,
    OllamaMessage,
    OllamaToolCall,
    OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


def _map_done_reason(
    finish_reason: str | None,
) -> OllamaDoneReason | None:
    if finish_reason is None:
        return None
    if finish_reason == "stop":
        return "stop"
    if finish_reason == "length":
        return "length"
    if finish_reason in ("tool_calls", "function_call"):
        return "tool_call"
    if finish_reason == "error":
        return "error"
    return "stop"


def _try_parse_json(value: str) -> dict[str, Any] | str:
    try:
        return json.loads(value)  # type: ignore
    except json.JSONDecodeError:
        return value


def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
    tool_calls: list[OllamaToolCall] = []
    for index, tool in enumerate(chunk.tool_calls):
        # tool.arguments is always str; try to parse as JSON dict for Ollama format
        arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
        tool_calls.append(
            OllamaToolCall(
                id=tool.id,
                type="function",
                function=OllamaToolFunction(
                    name=tool.name, arguments=arguments, index=index
                ),
            )
        )
    return tool_calls


def _get_usage(
    chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
    """Extract (prompt_eval_count, eval_count) from a chunk."""
    if chunk.usage is not None:
        return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
    if chunk.stats is not None:
        return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
    return (None, None)


def ollama_request_to_text_generation(
    request: OllamaChatRequest,
) -> TextGenerationTaskParams:
    """Convert Ollama chat request to exo's internal text generation format."""
    instructions: str | None = None
    input_messages: list[InputMessage] = []
    chat_template_messages: list[dict[str, Any]] = []
    tool_message_index = 0

    for msg in request.messages:
        content = msg.content or ""

        if msg.role == "system":
            if instructions is None:
                instructions = content
            else:
                instructions = f"{instructions}\n{content}"
            chat_template_messages.append({"role": "system", "content": content})
            continue

        if msg.role in ("user", "assistant") and (
            msg.content is not None or msg.thinking is not None or msg.tool_calls
        ):
            input_messages.append(InputMessage(role=msg.role, content=content))

        dumped: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.thinking is not None:
            dumped["thinking"] = msg.thinking
        if msg.tool_calls is not None:
            tool_calls_list: list[dict[str, Any]] = []
            for tc in msg.tool_calls:
                function: dict[str, Any] = {
                    "name": tc.function.name,
                    "arguments": (
                        json.dumps(tc.function.arguments)
                        if isinstance(tc.function.arguments, dict)
                        else tc.function.arguments
                    ),
                }
                if tc.function.index is not None:
                    function["index"] = tc.function.index
                tool_call: dict[str, Any] = {"function": function}
                if tc.id is not None:
                    tool_call["id"] = tc.id
                if tc.type is not None:
                    tool_call["type"] = tc.type
                tool_calls_list.append(tool_call)
            dumped["tool_calls"] = tool_calls_list
        if msg.name is not None:
            dumped["name"] = msg.name
        if msg.role == "tool":
            tool_message_index += 1
            tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
            dumped["tool_call_id"] = tool_call_id
            if msg.tool_name is not None:
                dumped["tool_name"] = msg.tool_name
        chat_template_messages.append(dumped)

    options = request.options
    return TextGenerationTaskParams(
        model=request.model,
        input=input_messages
        if input_messages
        else [InputMessage(role="user", content="")],
        instructions=instructions,
        max_output_tokens=options.num_predict if options else None,
        temperature=options.temperature if options else None,
        top_p=options.top_p if options else None,
        top_k=options.top_k if options else None,
        stop=options.stop if options else None,
        seed=options.seed if options else None,
        stream=request.stream,
        tools=request.tools,
        enable_thinking=request.think,
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
    )


async def generate_ollama_chat_stream(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate streaming responses in Ollama format (newline-delimited JSON)."""
    thinking_parts: list[str] = []

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                error_response = OllamaChatResponse(
                    model=str(chunk.model),
                    message=OllamaMessage(
                        role="assistant", content=chunk.error_message
                    ),
                    done=True,
                    done_reason="error",
                )
                yield f"{error_response.model_dump_json(exclude_none=True)}\n"
                return

            case ToolCallChunk():
                prompt_eval, eval_count = _get_usage(chunk)
                response = OllamaChatResponse(
                    model=str(chunk.model),
                    message=OllamaMessage(
                        role="assistant",
                        content="",
                        tool_calls=_build_tool_calls(chunk),
                        thinking="".join(thinking_parts) if thinking_parts else None,
                    ),
                    done=True,
                    done_reason="tool_call",
                    prompt_eval_count=prompt_eval,
                    eval_count=eval_count,
                )
                yield f"{response.model_dump_json(exclude_none=True)}\n"
                return

            case TokenChunk():
                done = chunk.finish_reason is not None

                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(
                            role="assistant", content="", thinking=chunk.text
                        ),
                        done=False,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"
                elif done:
                    prompt_eval, eval_count = _get_usage(chunk)
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(
                            role="assistant",
                            content=chunk.text,
                        ),
                        done=True,
                        done_reason=_map_done_reason(chunk.finish_reason),
                        prompt_eval_count=prompt_eval,
                        eval_count=eval_count,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"
                else:
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(role="assistant", content=chunk.text),
                        done=False,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"

                if done:
                    return


async def collect_ollama_chat_response(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    """Collect streaming chunks into a single non-streaming Ollama response.

    Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
    StreamingResponse cancellation handling.
    """
    text_parts: list[str] = []
    thinking_parts: list[str] = []
    tool_calls: list[OllamaToolCall] = []
    model: str | None = None
    finish_reason: str | None = None
    prompt_eval_count: int | None = None
    eval_count: int | None = None

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                raise ValueError(chunk.error_message or "Internal server error")

            case TokenChunk():
                if model is None:
                    model = str(chunk.model)
                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                else:
                    text_parts.append(chunk.text)
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason
                    prompt_eval_count, eval_count = _get_usage(chunk)

            case ToolCallChunk():
                if model is None:
                    model = str(chunk.model)
                tool_calls.extend(_build_tool_calls(chunk))
                finish_reason = chunk.finish_reason
                prompt_eval_count, eval_count = _get_usage(chunk)

    combined_text = "".join(text_parts)
    combined_thinking = "".join(thinking_parts) if thinking_parts else None
    assert model is not None

    yield OllamaChatResponse(
        model=model,
        message=OllamaMessage(
            role="assistant",
            content=combined_text,
            thinking=combined_thinking,
            tool_calls=tool_calls if tool_calls else None,
        ),
        done=True,
        done_reason=_map_done_reason(finish_reason),
        prompt_eval_count=prompt_eval_count,
        eval_count=eval_count,
    ).model_dump_json(exclude_none=True)
    return
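A small client-side sketch of consuming the newline-delimited JSON stream produced by generate_ollama_chat_stream; the host, port, route, and model id below are assumptions made for the example, not values taken from this diff:

    import json
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:52415/api/chat",  # assumed exo endpoint; adjust for your deployment
        data=json.dumps(
            {
                "model": "llama-3.2-1b",  # hypothetical model id
                "messages": [{"role": "user", "content": "hello"}],
                "stream": True,
            }
        ).encode(),
        headers={"Content-Type": "application/json"},
    )

    with urllib.request.urlopen(req) as resp:
        for raw_line in resp:  # one OllamaChatResponse JSON object per line
            chunk = json.loads(raw_line)
            message = chunk.get("message", {})
            # thinking tokens arrive with content="" and a separate thinking field
            print(message.get("thinking") or message.get("content", ""), end="", flush=True)
            if chunk.get("done"):
                break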


# ── /api/generate ──


def ollama_generate_request_to_text_generation(
    request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
    """Convert Ollama generate request to exo's internal text generation format."""
    chat_template_messages: list[dict[str, Any]] = []
    if request.system:
        chat_template_messages.append({"role": "system", "content": request.system})
    chat_template_messages.append({"role": "user", "content": request.prompt})

    options = request.options
    return TextGenerationTaskParams(
        model=request.model,
        input=[InputMessage(role="user", content=request.prompt)],
        instructions=request.system,
        max_output_tokens=options.num_predict if options else None,
        temperature=options.temperature if options else None,
        top_p=options.top_p if options else None,
        top_k=options.top_k if options else None,
        stop=options.stop if options else None,
        seed=options.seed if options else None,
        stream=request.stream,
        enable_thinking=request.think,
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
    )


async def generate_ollama_generate_stream(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate streaming responses for /api/generate in Ollama NDJSON format."""
    thinking_parts: list[str] = []

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                resp = OllamaGenerateResponse(
                    model=str(chunk.model),
                    response="",
                    done=True,
                    done_reason="error",
                )
                yield f"{resp.model_dump_json(exclude_none=True)}\n"
                return

            case ToolCallChunk():
                # generate endpoint doesn't support tools; emit as done
                prompt_eval, eval_count = _get_usage(chunk)
                resp = OllamaGenerateResponse(
                    model=str(chunk.model),
                    response="",
                    done=True,
                    done_reason="stop",
                    prompt_eval_count=prompt_eval,
                    eval_count=eval_count,
                )
                yield f"{resp.model_dump_json(exclude_none=True)}\n"
                return

            case TokenChunk():
                done = chunk.finish_reason is not None

                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response="",
                        thinking=chunk.text,
                        done=False,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"
                elif done:
                    prompt_eval, eval_count = _get_usage(chunk)
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response=chunk.text,
                        done=True,
                        done_reason=_map_done_reason(chunk.finish_reason),
                        prompt_eval_count=prompt_eval,
                        eval_count=eval_count,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"
                else:
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response=chunk.text,
                        done=False,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"

                if done:
                    return


async def collect_ollama_generate_response(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    """Collect chunks into a single non-streaming /api/generate response."""
    text_parts: list[str] = []
    thinking_parts: list[str] = []
    model: str | None = None
    finish_reason: str | None = None
    prompt_eval_count: int | None = None
    eval_count: int | None = None

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue
            case ErrorChunk():
                raise ValueError(chunk.error_message or "Internal server error")
            case TokenChunk():
                if model is None:
                    model = str(chunk.model)
                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                else:
                    text_parts.append(chunk.text)
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason
                    prompt_eval_count, eval_count = _get_usage(chunk)
            case ToolCallChunk():
                if model is None:
                    model = str(chunk.model)
                finish_reason = chunk.finish_reason
                prompt_eval_count, eval_count = _get_usage(chunk)

    assert model is not None
    yield OllamaGenerateResponse(
        model=model,
        response="".join(text_parts),
        thinking="".join(thinking_parts) if thinking_parts else None,
        done=True,
        done_reason=_map_done_reason(finish_reason),
        prompt_eval_count=prompt_eval_count,
        eval_count=eval_count,
    ).model_dump_json(exclude_none=True)
    return
@@ -29,8 +29,15 @@ from exo.shared.types.openai_responses import (
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponseOutputText,
+    ResponseReasoningItem,
+    ResponseReasoningSummaryPartAddedEvent,
+    ResponseReasoningSummaryPartDoneEvent,
+    ResponseReasoningSummaryText,
+    ResponseReasoningSummaryTextDeltaEvent,
+    ResponseReasoningSummaryTextDoneEvent,
    ResponsesRequest,
    ResponsesResponse,
+    ResponsesStreamEvent,
    ResponseTextDeltaEvent,
    ResponseTextDoneEvent,
    ResponseUsage,
@@ -38,6 +45,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


+def _format_sse(event: ResponsesStreamEvent) -> str:
+    """Format a streaming event as an SSE message."""
+    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
+
+
def _extract_content(content: str | list[ResponseContentPart]) -> str:
    """Extract plain text from a content field that may be a string or list of parts."""
    if isinstance(content, str):
@@ -135,7 +147,9 @@ async def collect_responses_response(
    """Collect all token chunks and return a single ResponsesResponse."""
    response_id = f"resp_{command_id}"
    item_id = f"item_{command_id}"
+    reasoning_id = f"rs_{command_id}"
    accumulated_text = ""
+    thinking_parts: list[str] = []
    function_call_items: list[ResponseFunctionCallItem] = []
    last_usage: Usage | None = None
    error_message: str | None = None
@@ -162,6 +176,10 @@ async def collect_responses_response(
            )
            continue

+        if chunk.is_thinking:
+            thinking_parts.append(chunk.text)
+            continue
+
        accumulated_text += chunk.text

    if error_message is not None:
@@ -176,13 +194,21 @@ async def collect_responses_response(
            total_tokens=last_usage.total_tokens,
        )

-    output: list[ResponseItem] = [
+    output: list[ResponseItem] = []
+    if thinking_parts:
+        output.append(
+            ResponseReasoningItem(
+                id=reasoning_id,
+                summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
+            )
+        )
+    output.append(
        ResponseMessageItem(
            id=item_id,
            content=[ResponseOutputText(text=accumulated_text)],
            status="completed",
        )
-    ]
+    )
    output.extend(function_call_items)

    yield ResponsesResponse(
@@ -206,6 +232,7 @@ async def generate_responses_stream(
    """Generate OpenAI Responses API streaming events from TokenChunks."""
    response_id = f"resp_{command_id}"
    item_id = f"item_{command_id}"
+    reasoning_id = f"rs_{command_id}"
    seq = count(1)

    # response.created
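A hedged, self-contained illustration of the wire format the _format_sse helper added above produces; DemoEvent here is a stand-in pydantic model for the example, not a type from this codebase:

    from pydantic import BaseModel

    class DemoEvent(BaseModel):  # stand-in for a ResponsesStreamEvent
        type: str = "response.output_text.delta"
        sequence_number: int = 1
        delta: str = "Hello"

    def format_sse(event: DemoEvent) -> str:  # same body as _format_sse above
        return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"

    print(format_sse(DemoEvent()), end="")
    # event: response.output_text.delta
    # data: {"type":"response.output_text.delta","sequence_number":1,"delta":"Hello"}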
@@ -219,40 +246,25 @@ async def generate_responses_stream(
    created_event = ResponseCreatedEvent(
        sequence_number=next(seq), response=initial_response
    )
-    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
+    yield _format_sse(created_event)

    # response.in_progress
    in_progress_event = ResponseInProgressEvent(
        sequence_number=next(seq), response=initial_response
    )
-    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
+    yield _format_sse(in_progress_event)

-    # response.output_item.added
-    initial_item = ResponseMessageItem(
-        id=item_id,
-        content=[ResponseOutputText(text="")],
-        status="in_progress",
-    )
-    item_added = ResponseOutputItemAddedEvent(
-        sequence_number=next(seq), output_index=0, item=initial_item
-    )
-    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
-
-    # response.content_part.added
-    initial_part = ResponseOutputText(text="")
-    part_added = ResponseContentPartAddedEvent(
-        sequence_number=next(seq),
-        item_id=item_id,
-        output_index=0,
-        content_index=0,
-        part=initial_part,
-    )
-    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"

    accumulated_text = ""
+    accumulated_thinking = ""
    function_call_items: list[ResponseFunctionCallItem] = []
    last_usage: Usage | None = None
-    next_output_index = 1  # message item is at 0
+    next_output_index = 0

+    # Track dynamic block creation
+    reasoning_started = False
+    reasoning_output_index = -1
+    message_started = False
+    message_output_index = -1
+
    async for chunk in chunk_stream:
        if isinstance(chunk, PrefillProgressChunk):
@@ -281,7 +293,7 @@ async def generate_responses_stream(
                output_index=next_output_index,
                item=fc_item,
            )
-            yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
+            yield _format_sse(fc_added)

            # response.function_call_arguments.delta
            args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +302,7 @@ async def generate_responses_stream(
                output_index=next_output_index,
                delta=tool.arguments,
            )
-            yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
+            yield _format_sse(args_delta)

            # response.function_call_arguments.done
            args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +312,7 @@ async def generate_responses_stream(
                name=tool.name,
                arguments=tool.arguments,
            )
-            yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
+            yield _format_sse(args_done)

            # response.output_item.done
            fc_done_item = ResponseFunctionCallItem(
@@ -315,44 +327,205 @@ async def generate_responses_stream(
                output_index=next_output_index,
                item=fc_done_item,
            )
-            yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
+            yield _format_sse(fc_item_done)

            function_call_items.append(fc_done_item)
            next_output_index += 1
            continue

+        if chunk.is_thinking:
+            # Start reasoning block on first thinking token
+            if not reasoning_started:
+                reasoning_started = True
+                reasoning_output_index = next_output_index
+                next_output_index += 1
+
+                # response.output_item.added for reasoning
+                reasoning_item = ResponseReasoningItem(
+                    id=reasoning_id,
+                    summary=[],
+                    status="in_progress",
+                )
+                rs_added = ResponseOutputItemAddedEvent(
+                    sequence_number=next(seq),
+                    output_index=reasoning_output_index,
+                    item=reasoning_item,
+                )
+                yield _format_sse(rs_added)
+
+                # response.reasoning_summary_part.added
+                part_added = ResponseReasoningSummaryPartAddedEvent(
+                    sequence_number=next(seq),
+                    item_id=reasoning_id,
+                    output_index=reasoning_output_index,
+                    summary_index=0,
+                    part=ResponseReasoningSummaryText(text=""),
+                )
+                yield _format_sse(part_added)
+
+            accumulated_thinking += chunk.text
+
+            # response.reasoning_summary_text.delta
+            rs_delta = ResponseReasoningSummaryTextDeltaEvent(
+                sequence_number=next(seq),
+                item_id=reasoning_id,
+                output_index=reasoning_output_index,
+                summary_index=0,
+                delta=chunk.text,
+            )
+            yield _format_sse(rs_delta)
+            continue
+
+        # Close reasoning block when transitioning to text
+        if reasoning_started and not message_started:
+            # response.reasoning_summary_text.done
+            rs_text_done = ResponseReasoningSummaryTextDoneEvent(
+                sequence_number=next(seq),
+                item_id=reasoning_id,
+                output_index=reasoning_output_index,
+                summary_index=0,
+                text=accumulated_thinking,
+            )
+            yield _format_sse(rs_text_done)
+
+            # response.reasoning_summary_part.done
+            rs_part_done = ResponseReasoningSummaryPartDoneEvent(
+                sequence_number=next(seq),
+                item_id=reasoning_id,
+                output_index=reasoning_output_index,
+                summary_index=0,
+                part=ResponseReasoningSummaryText(text=accumulated_thinking),
+            )
+            yield _format_sse(rs_part_done)
+
+            # response.output_item.done for reasoning
+            rs_item_done = ResponseOutputItemDoneEvent(
+                sequence_number=next(seq),
+                output_index=reasoning_output_index,
+                item=ResponseReasoningItem(
+                    id=reasoning_id,
+                    summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
+                ),
+            )
+            yield _format_sse(rs_item_done)
+
+        # Start message block on first text token
+        if not message_started:
+            message_started = True
+            message_output_index = next_output_index
+            next_output_index += 1
+
+            initial_item = ResponseMessageItem(
+                id=item_id,
+                content=[ResponseOutputText(text="")],
+                status="in_progress",
+            )
+            item_added = ResponseOutputItemAddedEvent(
+                sequence_number=next(seq),
+                output_index=message_output_index,
+                item=initial_item,
+            )
+            yield _format_sse(item_added)
+
+            initial_part = ResponseOutputText(text="")
+            part_added = ResponseContentPartAddedEvent(
+                sequence_number=next(seq),
+                item_id=item_id,
+                output_index=message_output_index,
+                content_index=0,
+                part=initial_part,
+            )
+            yield _format_sse(part_added)
+
        accumulated_text += chunk.text

        # response.output_text.delta
        delta_event = ResponseTextDeltaEvent(
            sequence_number=next(seq),
            item_id=item_id,
-            output_index=0,
+            output_index=message_output_index,
            content_index=0,
            delta=chunk.text,
        )
-        yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
+        yield _format_sse(delta_event)

+    # Close reasoning block if it was never followed by text
+    if reasoning_started and not message_started:
+        rs_text_done = ResponseReasoningSummaryTextDoneEvent(
+            sequence_number=next(seq),
+            item_id=reasoning_id,
+            output_index=reasoning_output_index,
+            summary_index=0,
+            text=accumulated_thinking,
+        )
+        yield _format_sse(rs_text_done)
+
+        rs_part_done = ResponseReasoningSummaryPartDoneEvent(
+            sequence_number=next(seq),
+            item_id=reasoning_id,
+            output_index=reasoning_output_index,
+            summary_index=0,
+            part=ResponseReasoningSummaryText(text=accumulated_thinking),
+        )
+        yield _format_sse(rs_part_done)
+
+        rs_item_done = ResponseOutputItemDoneEvent(
+            sequence_number=next(seq),
+            output_index=reasoning_output_index,
+            item=ResponseReasoningItem(
+                id=reasoning_id,
+                summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
+            ),
+        )
+        yield _format_sse(rs_item_done)
+
+    # If no message block was started, create one now (empty text)
+    if not message_started:
+        message_output_index = next_output_index
+        next_output_index += 1
+
+        initial_item = ResponseMessageItem(
+            id=item_id,
+            content=[ResponseOutputText(text="")],
+            status="in_progress",
+        )
+        item_added = ResponseOutputItemAddedEvent(
+            sequence_number=next(seq),
+            output_index=message_output_index,
+            item=initial_item,
+        )
+        yield _format_sse(item_added)
+
+        initial_part = ResponseOutputText(text="")
+        part_added_evt = ResponseContentPartAddedEvent(
+            sequence_number=next(seq),
|
item_id=item_id,
|
||||||
|
output_index=message_output_index,
|
||||||
|
content_index=0,
|
||||||
|
part=initial_part,
|
||||||
|
)
|
||||||
|
yield _format_sse(part_added_evt)
|
||||||
|
|
||||||
# response.output_text.done
|
# response.output_text.done
|
||||||
text_done = ResponseTextDoneEvent(
|
text_done = ResponseTextDoneEvent(
|
||||||
sequence_number=next(seq),
|
sequence_number=next(seq),
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=0,
|
output_index=message_output_index,
|
||||||
content_index=0,
|
content_index=0,
|
||||||
text=accumulated_text,
|
text=accumulated_text,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
yield _format_sse(text_done)
|
||||||
|
|
||||||
# response.content_part.done
|
# response.content_part.done
|
||||||
final_part = ResponseOutputText(text=accumulated_text)
|
final_part = ResponseOutputText(text=accumulated_text)
|
||||||
part_done = ResponseContentPartDoneEvent(
|
part_done = ResponseContentPartDoneEvent(
|
||||||
sequence_number=next(seq),
|
sequence_number=next(seq),
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=0,
|
output_index=message_output_index,
|
||||||
content_index=0,
|
content_index=0,
|
||||||
part=final_part,
|
part=final_part,
|
||||||
)
|
)
|
||||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
yield _format_sse(part_done)
|
||||||
|
|
||||||
# response.output_item.done
|
# response.output_item.done
|
||||||
final_message_item = ResponseMessageItem(
|
final_message_item = ResponseMessageItem(
|
||||||
@@ -361,9 +534,11 @@ async def generate_responses_stream(
|
|||||||
status="completed",
|
status="completed",
|
||||||
)
|
)
|
||||||
item_done = ResponseOutputItemDoneEvent(
|
item_done = ResponseOutputItemDoneEvent(
|
||||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
sequence_number=next(seq),
|
||||||
|
output_index=message_output_index,
|
||||||
|
item=final_message_item,
|
||||||
)
|
)
|
||||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
yield _format_sse(item_done)
|
||||||
|
|
||||||
# Create usage from usage data if available
|
# Create usage from usage data if available
|
||||||
usage = None
|
usage = None
|
||||||
@@ -375,7 +550,15 @@ async def generate_responses_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# response.completed
|
# response.completed
|
||||||
output: list[ResponseItem] = [final_message_item]
|
output: list[ResponseItem] = []
|
||||||
|
if reasoning_started:
|
||||||
|
output.append(
|
||||||
|
ResponseReasoningItem(
|
||||||
|
id=reasoning_id,
|
||||||
|
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output.append(final_message_item)
|
||||||
output.extend(function_call_items)
|
output.extend(function_call_items)
|
||||||
final_response = ResponsesResponse(
|
final_response = ResponsesResponse(
|
||||||
id=response_id,
|
id=response_id,
|
||||||
@@ -388,4 +571,4 @@ async def generate_responses_stream(
|
|||||||
completed_event = ResponseCompletedEvent(
|
completed_event = ResponseCompletedEvent(
|
||||||
sequence_number=next(seq), response=final_response
|
sequence_number=next(seq), response=final_response
|
||||||
)
|
)
|
||||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
yield _format_sse(completed_event)
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
|
|||||||
collect_claude_response,
|
collect_claude_response,
|
||||||
generate_claude_stream,
|
generate_claude_stream,
|
||||||
)
|
)
|
||||||
|
from exo.master.adapters.ollama import (
|
||||||
|
collect_ollama_chat_response,
|
||||||
|
collect_ollama_generate_response,
|
||||||
|
generate_ollama_chat_stream,
|
||||||
|
generate_ollama_generate_stream,
|
||||||
|
ollama_generate_request_to_text_generation,
|
||||||
|
ollama_request_to_text_generation,
|
||||||
|
)
|
||||||
from exo.master.adapters.responses import (
|
from exo.master.adapters.responses import (
|
||||||
collect_responses_response,
|
collect_responses_response,
|
||||||
generate_responses_stream,
|
generate_responses_stream,
|
||||||
@@ -138,10 +146,22 @@ from exo.shared.types.events import (
|
|||||||
Event,
|
Event,
|
||||||
ForwarderEvent,
|
ForwarderEvent,
|
||||||
IndexedEvent,
|
IndexedEvent,
|
||||||
PrefillProgress,
|
|
||||||
TracesMerged,
|
TracesMerged,
|
||||||
)
|
)
|
||||||
from exo.shared.types.memory import Memory
|
from exo.shared.types.memory import Memory
|
||||||
|
from exo.shared.types.ollama_api import (
|
||||||
|
OllamaChatRequest,
|
||||||
|
OllamaChatResponse,
|
||||||
|
OllamaGenerateRequest,
|
||||||
|
OllamaGenerateResponse,
|
||||||
|
OllamaModelDetails,
|
||||||
|
OllamaModelTag,
|
||||||
|
OllamaPsModel,
|
||||||
|
OllamaPsResponse,
|
||||||
|
OllamaShowRequest,
|
||||||
|
OllamaShowResponse,
|
||||||
|
OllamaTagsResponse,
|
||||||
|
)
|
||||||
from exo.shared.types.openai_responses import (
|
from exo.shared.types.openai_responses import (
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
ResponsesResponse,
|
ResponsesResponse,
|
||||||
@@ -301,6 +321,21 @@ class API:
|
|||||||
self.app.get("/images/{image_id}")(self.get_image)
|
self.app.get("/images/{image_id}")(self.get_image)
|
||||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||||
|
|
||||||
|
# Ollama API
|
||||||
|
self.app.head("/ollama/")(self.ollama_version)
|
||||||
|
self.app.head("/ollama/api/version")(self.ollama_version)
|
||||||
|
self.app.post("/ollama/api/chat", response_model=None)(self.ollama_chat)
|
||||||
|
self.app.post("/ollama/api/api/chat", response_model=None)(self.ollama_chat)
|
||||||
|
self.app.post("/ollama/api/v1/chat", response_model=None)(self.ollama_chat)
|
||||||
|
self.app.post("/ollama/api/generate", response_model=None)(self.ollama_generate)
|
||||||
|
self.app.get("/ollama/api/tags")(self.ollama_tags)
|
||||||
|
self.app.get("/ollama/api/api/tags")(self.ollama_tags)
|
||||||
|
self.app.get("/ollama/api/v1/tags")(self.ollama_tags)
|
||||||
|
self.app.post("/ollama/api/show")(self.ollama_show)
|
||||||
|
self.app.get("/ollama/api/ps")(self.ollama_ps)
|
||||||
|
self.app.get("/ollama/api/version")(self.ollama_version)
|
||||||
|
|
||||||
self.app.get("/state")(lambda: self.state)
|
self.app.get("/state")(lambda: self.state)
|
||||||
self.app.get("/events")(self.stream_events)
|
self.app.get("/events")(self.stream_events)
|
||||||
self.app.post("/download/start")(self.start_download)
|
self.app.post("/download/start")(self.start_download)
|
||||||
@@ -1294,6 +1329,163 @@ class API:
|
|||||||
media_type="application/json",
|
media_type="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _ollama_root(self) -> JSONResponse:
|
||||||
|
"""Respond to HEAD / from Ollama CLI connectivity checks."""
|
||||||
|
return JSONResponse(content="Ollama is running")
|
||||||
|
|
||||||
|
async def ollama_chat(
|
||||||
|
self, request: Request
|
||||||
|
) -> OllamaChatResponse | StreamingResponse:
|
||||||
|
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
|
||||||
|
body = await request.body()
|
||||||
|
payload = OllamaChatRequest.model_validate_json(body)
|
||||||
|
task_params = ollama_request_to_text_generation(payload)
|
||||||
|
resolved_model = await self._resolve_and_validate_text_model(
|
||||||
|
ModelId(task_params.model)
|
||||||
|
)
|
||||||
|
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||||
|
|
||||||
|
command = TextGeneration(task_params=task_params)
|
||||||
|
await self._send(command)
|
||||||
|
|
||||||
|
if payload.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_ollama_chat_stream(
|
||||||
|
command.command_id,
|
||||||
|
self._token_chunk_stream(command.command_id),
|
||||||
|
),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "close",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return StreamingResponse(
|
||||||
|
collect_ollama_chat_response(
|
||||||
|
command.command_id,
|
||||||
|
self._token_chunk_stream(command.command_id),
|
||||||
|
),
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ollama_generate(
|
||||||
|
self, request: Request
|
||||||
|
) -> OllamaGenerateResponse | StreamingResponse:
|
||||||
|
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
|
||||||
|
body = await request.body()
|
||||||
|
payload = OllamaGenerateRequest.model_validate_json(body)
|
||||||
|
task_params = ollama_generate_request_to_text_generation(payload)
|
||||||
|
resolved_model = await self._resolve_and_validate_text_model(
|
||||||
|
ModelId(task_params.model)
|
||||||
|
)
|
||||||
|
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||||
|
|
||||||
|
command = TextGeneration(task_params=task_params)
|
||||||
|
await self._send(command)
|
||||||
|
|
||||||
|
if payload.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_ollama_generate_stream(
|
||||||
|
command.command_id,
|
||||||
|
self._token_chunk_stream(command.command_id),
|
||||||
|
),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "close",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return StreamingResponse(
|
||||||
|
collect_ollama_generate_response(
|
||||||
|
command.command_id,
|
||||||
|
self._token_chunk_stream(command.command_id),
|
||||||
|
),
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ollama_tags(self) -> OllamaTagsResponse:
|
||||||
|
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
|
||||||
|
|
||||||
|
def none_if_empty(value: str) -> str | None:
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
downloaded_model_ids: set[str] = set()
|
||||||
|
for node_downloads in self.state.downloads.values():
|
||||||
|
for dl in node_downloads:
|
||||||
|
if isinstance(dl, DownloadCompleted):
|
||||||
|
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
|
||||||
|
|
||||||
|
cards = [
|
||||||
|
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||||
|
return OllamaTagsResponse(
|
||||||
|
models=[
|
||||||
|
OllamaModelTag(
|
||||||
|
name=str(card.model_id),
|
||||||
|
model=str(card.model_id),
|
||||||
|
modified_at=now,
|
||||||
|
size=card.storage_size.in_bytes,
|
||||||
|
digest="sha256:000000000000",
|
||||||
|
details=OllamaModelDetails(
|
||||||
|
family=none_if_empty(card.family),
|
||||||
|
quantization_level=none_if_empty(card.quantization),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for card in cards
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ollama_show(self, request: Request) -> OllamaShowResponse:
|
||||||
|
"""Returns model information in Ollama show format."""
|
||||||
|
body = await request.body()
|
||||||
|
payload = OllamaShowRequest.model_validate_json(body)
|
||||||
|
model_name = payload.name or payload.model
|
||||||
|
if not model_name:
|
||||||
|
raise HTTPException(status_code=400, detail="name or model is required")
|
||||||
|
try:
|
||||||
|
card = await ModelCard.load(ModelId(model_name))
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Model not found: {model_name}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return OllamaShowResponse(
|
||||||
|
modelfile=f"FROM {card.model_id}",
|
||||||
|
template="{{ .Prompt }}",
|
||||||
|
details=OllamaModelDetails(
|
||||||
|
family=card.family or None,
|
||||||
|
quantization_level=card.quantization or None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ollama_ps(self) -> OllamaPsResponse:
|
||||||
|
"""Returns list of running models (active instances)."""
|
||||||
|
models: list[OllamaPsModel] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for instance in self.state.instances.values():
|
||||||
|
model_id = str(instance.shard_assignments.model_id)
|
||||||
|
if model_id in seen:
|
||||||
|
continue
|
||||||
|
seen.add(model_id)
|
||||||
|
models.append(
|
||||||
|
OllamaPsModel(
|
||||||
|
name=model_id,
|
||||||
|
model=model_id,
|
||||||
|
size=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return OllamaPsResponse(models=models)
|
||||||
|
|
||||||
|
async def ollama_version(self) -> dict[str, str]:
|
||||||
|
"""Returns version information for Ollama API compatibility."""
|
||||||
|
return {"version": "exo v1.0"}
|
||||||
|
|
||||||
def _calculate_total_available_memory(self) -> Memory:
|
def _calculate_total_available_memory(self) -> Memory:
|
||||||
"""Calculate total available memory across all nodes in bytes."""
|
"""Calculate total available memory across all nodes in bytes."""
|
||||||
total_available = Memory()
|
total_available = Memory()
|
||||||
@@ -1323,7 +1515,7 @@ class API:
|
|||||||
name=card.model_id.short(),
|
name=card.model_id.short(),
|
||||||
description="",
|
description="",
|
||||||
tags=[],
|
tags=[],
|
||||||
storage_size_megabytes=int(card.storage_size.in_mb),
|
storage_size_megabytes=card.storage_size.in_mb,
|
||||||
supports_tensor=card.supports_tensor,
|
supports_tensor=card.supports_tensor,
|
||||||
tasks=[task.value for task in card.tasks],
|
tasks=[task.value for task in card.tasks],
|
||||||
is_custom=is_custom_card(card.model_id),
|
is_custom=is_custom_card(card.model_id),
|
||||||
@@ -1455,22 +1647,6 @@ class API:
|
|||||||
await queue.send(event.chunk)
|
await queue.send(event.chunk)
|
||||||
except BrokenResourceError:
|
except BrokenResourceError:
|
||||||
self._text_generation_queues.pop(event.command_id, None)
|
self._text_generation_queues.pop(event.command_id, None)
|
||||||
|
|
||||||
elif isinstance(event, PrefillProgress):
|
|
||||||
if queue := self._text_generation_queues.get(
|
|
||||||
event.command_id, None
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
await queue.send(
|
|
||||||
PrefillProgressChunk(
|
|
||||||
model=event.model,
|
|
||||||
processed_tokens=event.processed_tokens,
|
|
||||||
total_tokens=event.total_tokens,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except BrokenResourceError:
|
|
||||||
self._text_generation_queues.pop(event.command_id, None)
|
|
||||||
|
|
||||||
if isinstance(event, TracesMerged):
|
if isinstance(event, TracesMerged):
|
||||||
self._save_merged_trace(event)
|
self._save_merged_trace(event)
|
||||||
|
|
||||||
|
|||||||
@@ -141,15 +141,29 @@ def place_instance(
|
|||||||
if len(selected_cycle) == 1:
|
if len(selected_cycle) == 1:
|
||||||
command.instance_meta = InstanceMeta.MlxRing
|
command.instance_meta = InstanceMeta.MlxRing
|
||||||
|
|
||||||
# TODO: Single node instances
|
|
||||||
match command.instance_meta:
|
match command.instance_meta:
|
||||||
case InstanceMeta.MlxJaccl:
|
case InstanceMeta.MlxJaccl:
|
||||||
|
# TODO(evan): shard assignments should contain information about ranks, this is ugly
|
||||||
|
def get_device_rank(node_id: NodeId) -> int:
|
||||||
|
runner_id = shard_assignments.node_to_runner[node_id]
|
||||||
|
shard_metadata = shard_assignments.runner_to_shard.get(runner_id)
|
||||||
|
assert shard_metadata is not None
|
||||||
|
return shard_metadata.device_rank
|
||||||
|
|
||||||
|
zero_node_ids = [
|
||||||
|
node_id
|
||||||
|
for node_id in selected_cycle.node_ids
|
||||||
|
if get_device_rank(node_id) == 0
|
||||||
|
]
|
||||||
|
assert len(zero_node_ids) == 1
|
||||||
|
coordinator_node_id = zero_node_ids[0]
|
||||||
|
|
||||||
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
|
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
|
||||||
[node_id for node_id in selected_cycle],
|
[node_id for node_id in selected_cycle],
|
||||||
cycle_digraph,
|
cycle_digraph,
|
||||||
)
|
)
|
||||||
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
||||||
coordinator=selected_cycle.node_ids[0],
|
coordinator=coordinator_node_id,
|
||||||
coordinator_port=random_ephemeral_port(),
|
coordinator_port=random_ephemeral_port(),
|
||||||
cycle_digraph=cycle_digraph,
|
cycle_digraph=cycle_digraph,
|
||||||
node_network=node_network,
|
node_network=node_network,
|
||||||
|
|||||||
@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
|
|||||||
layer_allocations = allocate_layers_proportionally(
|
layer_allocations = allocate_layers_proportionally(
|
||||||
total_layers=model_card.n_layers,
|
total_layers=model_card.n_layers,
|
||||||
memory_fractions=[
|
memory_fractions=[
|
||||||
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
|
node_memory[node_id].ram_available / total_memory for node_id in node_ids
|
||||||
for node_id in node_ids
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
total_storage_bytes = model_card.storage_size.in_bytes
|
total_storage = model_card.storage_size
|
||||||
total_layers = model_card.n_layers
|
total_layers = model_card.n_layers
|
||||||
for i, node_id in enumerate(node_ids):
|
for i, node_id in enumerate(node_ids):
|
||||||
node_layers = layer_allocations[i]
|
node_layers = layer_allocations[i]
|
||||||
required_memory = (total_storage_bytes * node_layers) // total_layers
|
required_memory = (total_storage * node_layers) // total_layers
|
||||||
available_memory = node_memory[node_id].ram_available.in_bytes
|
available_memory = node_memory[node_id].ram_available
|
||||||
if required_memory > available_memory:
|
if required_memory > available_memory:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node {i} ({node_id}) has insufficient memory: "
|
f"Node {i} ({node_id}) has insufficient memory: "
|
||||||
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
|
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
|
||||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
f"but only has {available_memory.in_gb:.2f} GB available"
|
||||||
)
|
)
|
||||||
|
|
||||||
return layer_allocations
|
return layer_allocations
|
||||||
@@ -342,6 +341,7 @@ def _find_ip_prioritised(
|
|||||||
other_node_id: NodeId,
|
other_node_id: NodeId,
|
||||||
cycle_digraph: Topology,
|
cycle_digraph: Topology,
|
||||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||||
|
ring: bool,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Find an IP address between nodes with prioritization.
|
"""Find an IP address between nodes with prioritization.
|
||||||
|
|
||||||
@@ -354,13 +354,27 @@ def _find_ip_prioritised(
|
|||||||
ip_to_type = {
|
ip_to_type = {
|
||||||
iface.ip_address: iface.interface_type for iface in other_network.interfaces
|
iface.ip_address: iface.interface_type for iface in other_network.interfaces
|
||||||
}
|
}
|
||||||
priority = {
|
|
||||||
"ethernet": 0,
|
# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
|
||||||
"wifi": 1,
|
# TODO: Profile and get actual connection speeds.
|
||||||
"unknown": 2,
|
if ring:
|
||||||
"maybe_ethernet": 3,
|
priority = {
|
||||||
"thunderbolt": 4,
|
"thunderbolt": 0,
|
||||||
}
|
"maybe_ethernet": 1,
|
||||||
|
"ethernet": 2,
|
||||||
|
"wifi": 3,
|
||||||
|
"unknown": 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
# RDMA prefers ethernet coordinator
|
||||||
|
else:
|
||||||
|
priority = {
|
||||||
|
"ethernet": 0,
|
||||||
|
"wifi": 1,
|
||||||
|
"unknown": 2,
|
||||||
|
"maybe_ethernet": 3,
|
||||||
|
"thunderbolt": 4,
|
||||||
|
}
|
||||||
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
||||||
|
|
||||||
|
|
||||||
@@ -400,7 +414,7 @@ def get_mlx_ring_hosts_by_node(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
connection_ip = _find_ip_prioritised(
|
connection_ip = _find_ip_prioritised(
|
||||||
node_id, other_node_id, cycle_digraph, node_network
|
node_id, other_node_id, cycle_digraph, node_network, ring=True
|
||||||
)
|
)
|
||||||
if connection_ip is None:
|
if connection_ip is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -431,7 +445,9 @@ def get_mlx_jaccl_coordinators(
|
|||||||
if n == coordinator:
|
if n == coordinator:
|
||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
|
|
||||||
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
|
ip = _find_ip_prioritised(
|
||||||
|
n, coordinator, cycle_digraph, node_network, ring=False
|
||||||
|
)
|
||||||
if ip is not None:
|
if ip is not None:
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
|
|||||||
|
|
||||||
parsed = _parse_sse_events(events)
|
parsed = _parse_sse_events(events)
|
||||||
|
|
||||||
# Two tool block starts (at indices 1 and 2)
|
# Two tool block starts (at indices 0 and 1 — no text block when only tools)
|
||||||
tool_starts = [
|
tool_starts = [
|
||||||
e
|
e
|
||||||
for e in parsed
|
for e in parsed
|
||||||
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
|
|||||||
== "tool_use"
|
== "tool_use"
|
||||||
]
|
]
|
||||||
assert len(tool_starts) == 2
|
assert len(tool_starts) == 2
|
||||||
assert tool_starts[0]["index"] == 1
|
assert tool_starts[0]["index"] == 0
|
||||||
assert tool_starts[1]["index"] == 2
|
assert tool_starts[1]["index"] == 1
|
||||||
|
|
||||||
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
|
# Two tool block stops (at indices 0 and 1)
|
||||||
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
||||||
stop_indices = [e["index"] for e in block_stops]
|
stop_indices = [e["index"] for e in block_stops]
|
||||||
assert 0 in stop_indices
|
assert 0 in stop_indices
|
||||||
assert 1 in stop_indices
|
assert 1 in stop_indices
|
||||||
assert 2 in stop_indices
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from exo.utils.channels import channel
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_master():
|
async def test_master():
|
||||||
keypair = get_node_id_keypair()
|
keypair = get_node_id_keypair()
|
||||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
node_id = NodeId(keypair.to_node_id())
|
||||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||||
|
|
||||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||||
@@ -75,7 +75,7 @@ async def test_master():
|
|||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
tg.start_soon(master.run)
|
tg.start_soon(master.run)
|
||||||
|
|
||||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
|
||||||
# inject a NodeGatheredInfo event
|
# inject a NodeGatheredInfo event
|
||||||
logger.info("inject a NodeGatheredInfo event")
|
logger.info("inject a NodeGatheredInfo event")
|
||||||
await local_event_sender.send(
|
await local_event_sender.send(
|
||||||
|
|||||||
@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
|
|||||||
):
|
):
|
||||||
# arrange
|
# arrange
|
||||||
model_card.n_layers = total_layers
|
model_card.n_layers = total_layers
|
||||||
model_card.storage_size.in_bytes = sum(
|
model_card.storage_size = Memory.from_bytes(
|
||||||
available_memory
|
sum(available_memory)
|
||||||
) # make it exactly fit across all nodes
|
) # make it exactly fit across all nodes
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
|
|
||||||
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
|||||||
# arrange
|
# arrange
|
||||||
topology = Topology()
|
topology = Topology()
|
||||||
model_card.n_layers = 12
|
model_card.n_layers = 12
|
||||||
model_card.storage_size.in_bytes = 1500
|
model_card.storage_size = Memory.from_bytes(1500)
|
||||||
|
|
||||||
node_a = NodeId()
|
node_a = NodeId()
|
||||||
node_b = NodeId()
|
node_b = NodeId()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||||
return cls(
|
return cls(
|
||||||
node_id=NodeId(update.peer_id.to_base58()),
|
node_id=NodeId(update.peer_id),
|
||||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||||
remote_ipv4=update.remote_ipv4,
|
remote_ipv4=update.remote_ipv4,
|
||||||
remote_tcp_port=update.remote_tcp_port,
|
remote_tcp_port=update.remote_tcp_port,
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def get_node_id_keypair(
|
|||||||
Obtain the :class:`PeerId` by from it.
|
Obtain the :class:`PeerId` by from it.
|
||||||
"""
|
"""
|
||||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||||
return Keypair.generate_ed25519()
|
return Keypair.generate()
|
||||||
|
|
||||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||||
return Path(str(path) + ".lock")
|
return Path(str(path) + ".lock")
|
||||||
@@ -235,12 +235,12 @@ def get_node_id_keypair(
|
|||||||
protobuf_encoded = f.read()
|
protobuf_encoded = f.read()
|
||||||
|
|
||||||
try: # if decoded successfully, save & return
|
try: # if decoded successfully, save & return
|
||||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
return Keypair.from_bytes(protobuf_encoded)
|
||||||
except ValueError as e: # on runtime error, assume corrupt file
|
except ValueError as e: # on runtime error, assume corrupt file
|
||||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||||
|
|
||||||
# if no valid credentials, create new ones and persist
|
# if no valid credentials, create new ones and persist
|
||||||
with open(path, "w+b") as f:
|
with open(path, "w+b") as f:
|
||||||
keypair = Keypair.generate_ed25519()
|
keypair = Keypair.generate_ed25519()
|
||||||
f.write(keypair.to_protobuf_encoding())
|
f.write(keypair.to_bytes())
|
||||||
return keypair
|
return keypair
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from exo.shared.types.events import (
|
|||||||
NodeDownloadProgress,
|
NodeDownloadProgress,
|
||||||
NodeGatheredInfo,
|
NodeGatheredInfo,
|
||||||
NodeTimedOut,
|
NodeTimedOut,
|
||||||
PrefillProgress,
|
|
||||||
RunnerDeleted,
|
RunnerDeleted,
|
||||||
RunnerStatusUpdated,
|
RunnerStatusUpdated,
|
||||||
TaskAcknowledged,
|
TaskAcknowledged,
|
||||||
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
|
|||||||
| ChunkGenerated()
|
| ChunkGenerated()
|
||||||
| TaskAcknowledged()
|
| TaskAcknowledged()
|
||||||
| InputChunkReceived()
|
| InputChunkReceived()
|
||||||
| PrefillProgress()
|
|
||||||
| TracesCollected()
|
| TracesCollected()
|
||||||
| TracesMerged()
|
| TracesMerged()
|
||||||
): # Pass-through events that don't modify state
|
): # Pass-through events that don't modify state
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
|
|||||||
event = DownloadCompleted(
|
event = DownloadCompleted(
|
||||||
node_id=NodeId("node-1"),
|
node_id=NodeId("node-1"),
|
||||||
shard_metadata=shard1,
|
shard_metadata=shard1,
|
||||||
total_bytes=Memory(),
|
total=Memory(),
|
||||||
)
|
)
|
||||||
|
|
||||||
new_state = apply_node_download_progress(
|
new_state = apply_node_download_progress(
|
||||||
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
|
|||||||
event1 = DownloadCompleted(
|
event1 = DownloadCompleted(
|
||||||
node_id=NodeId("node-1"),
|
node_id=NodeId("node-1"),
|
||||||
shard_metadata=shard1,
|
shard_metadata=shard1,
|
||||||
total_bytes=Memory(),
|
total=Memory(),
|
||||||
)
|
)
|
||||||
event2 = DownloadCompleted(
|
event2 = DownloadCompleted(
|
||||||
node_id=NodeId("node-1"),
|
node_id=NodeId("node-1"),
|
||||||
shard_metadata=shard2,
|
shard_metadata=shard2,
|
||||||
total_bytes=Memory(),
|
total=Memory(),
|
||||||
)
|
)
|
||||||
state = State(downloads={NodeId("node-1"): [event1]})
|
state = State(downloads={NodeId("node-1"): [event1]})
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
|||||||
sem.release()
|
sem.release()
|
||||||
# wait to be told to begin simultaneous read
|
# wait to be told to begin simultaneous read
|
||||||
ev.wait()
|
ev.wait()
|
||||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
queue.put(get_node_id_keypair().to_bytes())
|
||||||
|
|
||||||
|
|
||||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
|
|||||||
content: (
|
content: (
|
||||||
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
|
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
|
||||||
) = None
|
) = None
|
||||||
thinking: str | None = None # Added for GPT-OSS harmony format support
|
reasoning_content: str | None = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
tool_calls: list[ToolCall] | None = None
|
tool_calls: list[ToolCall] | None = None
|
||||||
tool_call_id: str | None = None
|
tool_call_id: str | None = None
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
|
|||||||
stats: GenerationStats | None = None
|
stats: GenerationStats | None = None
|
||||||
logprob: float | None = None
|
logprob: float | None = None
|
||||||
top_logprobs: list[TopLogprobItem] | None = None
|
top_logprobs: list[TopLogprobItem] | None = None
|
||||||
|
is_thinking: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ErrorChunk(BaseChunk):
|
class ErrorChunk(BaseChunk):
|
||||||
|
|||||||
@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
|
|||||||
source: ClaudeImageSource
|
source: ClaudeImageSource
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeThinkingBlock(BaseModel, frozen=True):
|
||||||
|
"""Thinking content block in Claude Messages API."""
|
||||||
|
|
||||||
|
type: Literal["thinking"] = "thinking"
|
||||||
|
thinking: str
|
||||||
|
signature: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ClaudeToolUseBlock(BaseModel, frozen=True):
|
class ClaudeToolUseBlock(BaseModel, frozen=True):
|
||||||
"""Tool use content block in Claude Messages API."""
|
"""Tool use content block in Claude Messages API."""
|
||||||
|
|
||||||
@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
|
|||||||
cache_control: dict[str, str] | None = None
|
cache_control: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
|
ClaudeContentBlock = (
|
||||||
|
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||||
|
)
|
||||||
|
|
||||||
# Input content blocks can also include tool_result (sent by user after tool_use)
|
# Input content blocks can also include tool_result (sent by user after tool_use)
|
||||||
ClaudeInputContentBlock = (
|
ClaudeInputContentBlock = (
|
||||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
|
ClaudeTextBlock
|
||||||
|
| ClaudeImageBlock
|
||||||
|
| ClaudeThinkingBlock
|
||||||
|
| ClaudeToolUseBlock
|
||||||
|
| ClaudeToolResultBlock
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
|
|||||||
content: str | list[ClaudeInputContentBlock]
|
content: str | list[ClaudeInputContentBlock]
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeThinkingConfig(BaseModel, frozen=True):
|
||||||
|
type: Literal["enabled", "disabled", "adaptive"]
|
||||||
|
budget_tokens: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class ClaudeMessagesRequest(BaseModel):
|
class ClaudeMessagesRequest(BaseModel):
|
||||||
"""Request body for Claude Messages API."""
|
"""Request body for Claude Messages API."""
|
||||||
|
|
||||||
@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
|
|||||||
top_k: int | None = None
|
top_k: int | None = None
|
||||||
tools: list[ClaudeToolDefinition] | None = None
|
tools: list[ClaudeToolDefinition] | None = None
|
||||||
metadata: dict[str, str] | None = None
|
metadata: dict[str, str] | None = None
|
||||||
|
thinking: ClaudeThinkingConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
# Response types
|
# Response types
|
||||||
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
|||||||
|
|
||||||
type: Literal["content_block_start"] = "content_block_start"
|
type: Literal["content_block_start"] = "content_block_start"
|
||||||
index: int
|
index: int
|
||||||
content_block: ClaudeTextBlock | ClaudeToolUseBlock
|
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||||
|
|
||||||
|
|
||||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||||
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
|
|||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeThinkingDelta(BaseModel, frozen=True):
|
||||||
|
"""Delta for thinking content block."""
|
||||||
|
|
||||||
|
type: Literal["thinking_delta"] = "thinking_delta"
|
||||||
|
thinking: str
|
||||||
|
|
||||||
|
|
||||||
class ClaudeInputJsonDelta(BaseModel, frozen=True):
|
class ClaudeInputJsonDelta(BaseModel, frozen=True):
|
||||||
"""Delta for tool use input JSON content block."""
|
"""Delta for tool use input JSON content block."""
|
||||||
|
|
||||||
@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
|||||||
|
|
||||||
type: Literal["content_block_delta"] = "content_block_delta"
|
type: Literal["content_block_delta"] = "content_block_delta"
|
||||||
index: int
|
index: int
|
||||||
delta: ClaudeTextDelta | ClaudeInputJsonDelta
|
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta
|
||||||
|
|
||||||
|
|
||||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
|||||||
|
|
||||||
from exo.shared.topology import Connection
|
from exo.shared.topology import Connection
|
||||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||||
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
|
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||||
from exo.shared.types.worker.downloads import DownloadProgress
|
from exo.shared.types.worker.downloads import DownloadProgress
|
||||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||||
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
|
|||||||
chunk: InputImageChunk
|
chunk: InputImageChunk
|
||||||
|
|
||||||
|
|
||||||
class PrefillProgress(BaseEvent):
|
|
||||||
command_id: CommandId
|
|
||||||
model: ModelId
|
|
||||||
processed_tokens: int
|
|
||||||
total_tokens: int
|
|
||||||
|
|
||||||
|
|
||||||
class TopologyEdgeCreated(BaseEvent):
|
class TopologyEdgeCreated(BaseEvent):
|
||||||
conn: Connection
|
conn: Connection
|
||||||
|
|
||||||
@@ -155,7 +148,6 @@ Event = (
|
|||||||
| NodeDownloadProgress
|
| NodeDownloadProgress
|
||||||
| ChunkGenerated
|
| ChunkGenerated
|
||||||
| InputChunkReceived
|
| InputChunkReceived
|
||||||
| PrefillProgress
|
|
||||||
| TopologyEdgeCreated
|
| TopologyEdgeCreated
|
||||||
| TopologyEdgeDeleted
|
| TopologyEdgeDeleted
|
||||||
| TracesCollected
|
| TracesCollected
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Self
|
from typing import Self, overload
|
||||||
|
|
||||||
from exo.utils.pydantic_ext import CamelCaseModel
|
from exo.utils.pydantic_ext import FrozenModel
|
||||||
|
|
||||||
|
|
||||||
class Memory(CamelCaseModel):
|
class Memory(FrozenModel):
|
||||||
in_bytes: int = 0
|
in_bytes: int = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
|
|||||||
return cls(in_bytes=round(val * 1024))
|
return cls(in_bytes=round(val * 1024))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def in_mb(self) -> float:
|
def in_mb(self) -> int:
|
||||||
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
|
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
|
||||||
return self.in_bytes / (1024**2)
|
return round(self.in_bytes / (1024**2))
|
||||||
|
|
||||||
@in_mb.setter
|
@in_mb.setter
|
||||||
def in_mb(self, val: float):
|
def in_mb(self, val: int):
|
||||||
|
"""Set the megabytes for this memory."""
|
||||||
|
self.in_bytes = val * (1024**2)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_float_mb(self) -> float:
|
||||||
|
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
|
||||||
|
return self.in_bytes / (1024**2)
|
||||||
|
|
||||||
|
@in_float_mb.setter
|
||||||
|
def in_float_mb(self, val: float):
|
||||||
"""Set the megabytes for this memory, rounded to the nearest byte."""
|
"""Set the megabytes for this memory, rounded to the nearest byte."""
|
||||||
self.in_bytes = round(val * (1024**2))
|
self.in_bytes = round(val * (1024**2))
|
||||||
|
|
||||||
@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
|
|||||||
"""The approximate gigabytes this memory represents."""
|
"""The approximate gigabytes this memory represents."""
|
||||||
return self.in_bytes / (1024**3)
|
return self.in_bytes / (1024**3)
|
||||||
|
|
||||||
def __add__(self, other: "Memory") -> "Memory":
|
def __add__(self, other: object) -> "Memory":
|
||||||
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
if isinstance(other, Memory):
|
||||||
|
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __lt__(self, other: Self) -> bool:
|
def __radd__(self, other: object) -> "Memory":
|
||||||
return self.in_bytes < other.in_bytes
|
if other == 0:
|
||||||
|
return self
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __le__(self, other: Self) -> bool:
|
def __sub__(self, other: object) -> "Memory":
|
||||||
return self.in_bytes <= other.in_bytes
|
if isinstance(other, Memory):
|
||||||
|
return Memory.from_bytes(self.in_bytes - other.in_bytes)
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __gt__(self, other: Self) -> bool:
|
def __mul__(self, other: int | float):
|
||||||
return self.in_bytes > other.in_bytes
|
return Memory.from_bytes(round(self.in_bytes * other))
|
||||||
|
|
||||||
def __ge__(self, other: Self) -> bool:
|
def __rmul__(self, other: int | float):
|
||||||
return self.in_bytes >= other.in_bytes
|
return self * other
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __truediv__(self, other: "Memory") -> float: ...
|
||||||
|
@overload
|
||||||
|
def __truediv__(self, other: int) -> "Memory": ...
|
||||||
|
@overload
|
||||||
|
def __truediv__(self, other: float) -> "Memory": ...
|
||||||
|
def __truediv__(self, other: object) -> "Memory | float":
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes / other.in_bytes
|
||||||
|
if isinstance(other, (int, float)):
|
||||||
|
return Memory.from_bytes(round(self.in_bytes / other))
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __floordiv__(self, other: object) -> "Memory":
|
||||||
|
if isinstance(other, (int, float)):
|
||||||
|
return Memory.from_bytes(int(self.in_bytes // other))
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __lt__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes < other.in_bytes
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __le__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes <= other.in_bytes
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __gt__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes > other.in_bytes
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __ge__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes >= other.in_bytes
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if isinstance(other, Memory):
|
||||||
|
return self.in_bytes == other.in_bytes
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Memory.from_bytes({self.in_bytes})"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.in_gb > 2:
|
||||||
|
val = self.in_gb
|
||||||
|
unit = "GiB"
|
||||||
|
elif self.in_mb > 2:
|
||||||
|
val = self.in_mb
|
||||||
|
unit = "MiB"
|
||||||
|
elif self.in_kb > 3:
|
||||||
|
val = self.in_kb
|
||||||
|
unit = "KiB"
|
||||||
|
else:
|
||||||
|
val = self.in_bytes
|
||||||
|
unit = "B"
|
||||||
|
|
||||||
|
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
|
||||||
|
|||||||
148
src/exo/shared/types/ollama_api.py
Normal file
148
src/exo/shared/types/ollama_api.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from exo.shared.models.model_cards import ModelId
|
||||||
|
|
||||||
|
# https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||||
|
|
||||||
|
OllamaRole = Literal["system", "user", "assistant", "tool"]
|
||||||
|
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaToolFunction(BaseModel, frozen=True):
|
||||||
|
name: str
|
||||||
|
arguments: dict[str, Any] | str
|
||||||
|
index: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaToolCall(BaseModel, frozen=True):
|
||||||
|
id: str | None = None
|
||||||
|
type: Literal["function"] | None = None
|
||||||
|
function: OllamaToolFunction
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaMessage(BaseModel, frozen=True):
|
||||||
|
role: OllamaRole
|
||||||
|
content: str | None = None
|
||||||
|
thinking: str | None = None
|
||||||
|
tool_calls: list[OllamaToolCall] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
tool_name: str | None = None
|
||||||
|
images: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaOptions(BaseModel, frozen=True):
|
||||||
|
num_predict: int | None = None
|
||||||
|
temperature: float | None = None
|
||||||
|
top_p: float | None = None
|
||||||
|
top_k: int | None = None
|
||||||
|
stop: str | list[str] | None = None
|
||||||
|
seed: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChatRequest(BaseModel, frozen=True):
|
||||||
|
model: ModelId
|
||||||
|
messages: list[OllamaMessage]
|
||||||
|
stream: bool = True
|
||||||
|
options: OllamaOptions | None = None
|
||||||
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
format: Literal["json"] | dict[str, Any] | None = None
|
||||||
|
keep_alive: str | int | None = None
|
||||||
|
think: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaGenerateRequest(BaseModel, frozen=True):
|
||||||
|
model: ModelId
|
||||||
|
prompt: str = ""
|
||||||
|
system: str | None = None
|
||||||
|
stream: bool = True
|
||||||
|
options: OllamaOptions | None = None
|
||||||
|
format: Literal["json"] | dict[str, Any] | None = None
|
||||||
|
keep_alive: str | int | None = None
|
||||||
|
think: bool | None = None
|
||||||
|
raw: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
|
||||||
|
model: str
|
||||||
|
created_at: str = Field(
|
||||||
|
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||||
|
)
|
||||||
|
response: str
|
||||||
|
thinking: str | None = None
|
||||||
|
done: bool
|
||||||
|
done_reason: OllamaDoneReason | None = None
|
||||||
|
total_duration: int | None = None
|
||||||
|
load_duration: int | None = None
|
||||||
|
prompt_eval_count: int | None = None
|
||||||
|
prompt_eval_duration: int | None = None
|
||||||
|
eval_count: int | None = None
|
||||||
|
eval_duration: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaShowRequest(BaseModel, frozen=True):
|
||||||
|
name: str | None = None
|
||||||
|
model: str | None = None
|
||||||
|
verbose: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
|
||||||
|
model: str
|
||||||
|
created_at: str = Field(
|
||||||
|
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||||
|
)
|
||||||
|
message: OllamaMessage
|
||||||
|
done: bool
|
||||||
|
done_reason: OllamaDoneReason | None = None
|
||||||
|
total_duration: int | None = None
|
||||||
|
load_duration: int | None = None
|
||||||
|
prompt_eval_count: int | None = None
|
||||||
|
prompt_eval_duration: int | None = None
|
||||||
|
eval_count: int | None = None
|
||||||
|
eval_duration: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
|
||||||
|
format: str | None = None
|
||||||
|
family: str | None = None
|
||||||
|
parameter_size: str | None = None
|
||||||
|
quantization_level: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaModelTag(BaseModel, frozen=True, strict=True):
|
||||||
|
name: str
|
||||||
|
model: str | None = None
|
||||||
|
modified_at: str | None = None
|
||||||
|
size: int | None = None
|
||||||
|
digest: str | None = None
|
||||||
|
details: OllamaModelDetails | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
|
||||||
|
models: list[OllamaModelTag]
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
|
||||||
|
modelfile: str | None = None
|
||||||
|
parameters: str | None = None
|
||||||
|
template: str | None = None
|
||||||
|
details: OllamaModelDetails | None = None
|
||||||
|
model_info: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaPsModel(BaseModel, frozen=True, strict=True):
|
||||||
|
name: str
|
||||||
|
model: str
|
||||||
|
size: int
|
||||||
|
digest: str | None = None
|
||||||
|
details: OllamaModelDetails | None = None
|
||||||
|
expires_at: str | None = None
|
||||||
|
size_vram: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
|
||||||
|
models: list[OllamaPsModel]
|
||||||
@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
|
|||||||
status: ResponseStatus = "completed"
|
status: ResponseStatus = "completed"
|
||||||
|
|
||||||
|
|
||||||
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
|
class ResponseReasoningSummaryText(BaseModel, frozen=True):
|
||||||
|
"""Summary text part in a reasoning output item."""
|
||||||
|
|
||||||
|
type: Literal["summary_text"] = "summary_text"
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseReasoningItem(BaseModel, frozen=True):
|
||||||
|
"""Reasoning output item in response output array."""
|
||||||
|
|
||||||
|
type: Literal["reasoning"] = "reasoning"
|
||||||
|
id: str
|
||||||
|
summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
|
||||||
|
status: ResponseStatus = "completed"
|
||||||
|
|
||||||
|
|
||||||
|
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem
|
||||||
|
|
||||||
|
|
||||||
class ResponseUsage(BaseModel, frozen=True):
|
class ResponseUsage(BaseModel, frozen=True):
|
||||||
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
|
|||||||
arguments: str
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
|
||||||
|
"""Event sent when a reasoning summary part is added."""
|
||||||
|
|
||||||
|
type: Literal["response.reasoning_summary_part.added"] = (
|
||||||
|
"response.reasoning_summary_part.added"
|
||||||
|
)
|
||||||
|
sequence_number: int
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
summary_index: int
|
||||||
|
part: ResponseReasoningSummaryText
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
|
||||||
|
"""Event sent for reasoning summary text delta during streaming."""
|
||||||
|
|
||||||
|
type: Literal["response.reasoning_summary_text.delta"] = (
|
||||||
|
"response.reasoning_summary_text.delta"
|
||||||
|
)
|
||||||
|
sequence_number: int
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
summary_index: int
|
||||||
|
delta: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
|
||||||
|
"""Event sent when reasoning summary text is done."""
|
||||||
|
|
||||||
|
type: Literal["response.reasoning_summary_text.done"] = (
|
||||||
|
"response.reasoning_summary_text.done"
|
||||||
|
)
|
||||||
|
sequence_number: int
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
summary_index: int
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
|
||||||
|
"""Event sent when a reasoning summary part is done."""
|
||||||
|
|
||||||
|
type: Literal["response.reasoning_summary_part.done"] = (
|
||||||
|
"response.reasoning_summary_part.done"
|
||||||
|
)
|
||||||
|
sequence_number: int
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
summary_index: int
|
||||||
|
part: ResponseReasoningSummaryText
|
||||||
|
|
||||||
|
|
||||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||||
"""Event sent when response is completed."""
|
"""Event sent when response is completed."""
|
||||||
|
|
||||||
@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
|
|||||||
| ResponseOutputItemDoneEvent
|
| ResponseOutputItemDoneEvent
|
||||||
| ResponseFunctionCallArgumentsDeltaEvent
|
| ResponseFunctionCallArgumentsDeltaEvent
|
||||||
| ResponseFunctionCallArgumentsDoneEvent
|
| ResponseFunctionCallArgumentsDoneEvent
|
||||||
|
| ResponseReasoningSummaryPartAddedEvent
|
||||||
|
| ResponseReasoningSummaryTextDeltaEvent
|
||||||
|
| ResponseReasoningSummaryTextDoneEvent
|
||||||
|
| ResponseReasoningSummaryPartDoneEvent
|
||||||
| ResponseCompletedEvent
|
| ResponseCompletedEvent
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
|||||||
|
|
||||||
|
|
||||||
class DownloadProgressData(CamelCaseModel):
|
class DownloadProgressData(CamelCaseModel):
|
||||||
total_bytes: Memory
|
total: Memory
|
||||||
downloaded_bytes: Memory
|
downloaded: Memory
|
||||||
downloaded_bytes_this_session: Memory
|
downloaded_this_session: Memory
|
||||||
|
|
||||||
completed_files: int
|
completed_files: int
|
||||||
total_files: int
|
total_files: int
|
||||||
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):
|
|||||||
|
|
||||||
|
|
||||||
class DownloadCompleted(BaseDownloadProgress):
|
class DownloadCompleted(BaseDownloadProgress):
|
||||||
total_bytes: Memory
|
total: Memory
|
||||||
|
|
||||||
|
|
||||||
class DownloadFailed(BaseDownloadProgress):
|
class DownloadFailed(BaseDownloadProgress):
|
||||||
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
     shard: ShardMetadata
     completed_files: int
     total_files: int
-    downloaded_bytes: Memory
+    downloaded: Memory
-    downloaded_bytes_this_session: Memory
+    downloaded_this_session: Memory
-    total_bytes: Memory
+    total: Memory
     overall_speed: float
     overall_eta: timedelta
     status: Literal["not_started", "in_progress", "complete"]
@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
     usage: Usage | None
+    is_thinking: bool = False


 class ImageGenerationResponse(BaseRunnerResponse):
@@ -192,7 +192,13 @@ class MpReceiver[T]:
         try:
             return self.receive_nowait()
         except WouldBlock:
-            item = self._state.buffer.get()
+            try:
+                item = self._state.buffer.get()
+            except (TypeError, OSError):
+                # Queue pipe can get closed while we are blocked on get().
+                # The underlying connection._handle becomes None, causing
+                # TypeError in read(handle, remaining).
+                raise ClosedResourceError from None
             if isinstance(item, _MpEndOfStream):
                 self.close()
                 raise EndOfStream from None
@@ -108,7 +108,7 @@ async def check_reachable(
     await send.send((target_ip, expected_node_id))

     async with (
-        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
+        httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
         create_task_group() as tg,
     ):
         for node_id in topology.list_nodes():
@@ -166,7 +166,7 @@ def generate_image(
         else 0.0
     )

-    peak_memory_gb = mx.get_peak_memory() / (1024**3)
+    peak_memory = Memory.from_bytes(mx.get_peak_memory())

     stats = ImageGenerationStats(
         seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
         num_images=num_images,
         image_width=width,
         image_height=height,
-        peak_memory_usage=Memory.from_gb(peak_memory_gb),
+        peak_memory_usage=peak_memory,
     )

     buffer = io.BytesIO()
@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
 # Fraction of device memory above which LRU eviction kicks in.
 # Smaller machines need more aggressive eviction.
 def _default_memory_threshold() -> float:
-    total_gb = psutil.virtual_memory().total / (1024**3)
+    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
     if total_gb >= 128:
         return 0.85
     if total_gb >= 64:
72  src/exo/worker/engines/mlx/dsml_encoding.py  Normal file
@@ -0,0 +1,72 @@
import json
import re
from typing import Any

from mlx_lm.chat_templates import deepseek_v32

from exo.shared.types.api import ToolCallItem

BOS_TOKEN: str = deepseek_v32.bos_token
EOS_TOKEN: str = deepseek_v32.eos_token
DSML_TOKEN: str = deepseek_v32.dsml_token
THINKING_START: str = deepseek_v32.thinking_start_token
THINKING_END: str = deepseek_v32.thinking_end_token
USER_TOKEN = "<\uff5cUser\uff5c>"
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
encode_messages = deepseek_v32.encode_messages

_INVOKE_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}invoke>",
    re.DOTALL,
)

_PARAM_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}parameter>",
    re.DOTALL,
)


def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
    """Parse DSML function_calls block from model output text.

    Args:
        text: The text containing the DSML function_calls block
            (including the start/end markers).

    Returns:
        List of ToolCallItem, or None if parsing fails.
    """
    tool_calls: list[ToolCallItem] = []

    for invoke_match in _INVOKE_PATTERN.finditer(text):
        func_name = invoke_match.group(1)
        invoke_body = invoke_match.group(2)

        args: dict[str, Any] = {}
        for param_match in _PARAM_PATTERN.finditer(invoke_body):
            param_name = param_match.group(1)
            is_string = param_match.group(2) == "true"
            param_value = param_match.group(3)

            if is_string:
                args[param_name] = param_value
            else:
                try:
                    args[param_name] = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    args[param_name] = param_value

        tool_calls.append(
            ToolCallItem(
                name=func_name,
                arguments=json.dumps(args),
            )
        )

    return tool_calls if tool_calls else None
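As a usage sketch (not part of the diff), the constants and parser in the new module compose as follows; this mirrors what the end-to-end tests further down in this compare do:

    from exo.worker.engines.mlx.dsml_encoding import (
        DSML_TOKEN,
        TOOL_CALLS_END,
        TOOL_CALLS_START,
        parse_dsml_output,
    )

    # Build a single-invoke DSML block from the module's own markers and parse it back.
    block = (
        f"{TOOL_CALLS_START}\n"
        f'<{DSML_TOKEN}invoke name="get_weather">\n'
        f'<{DSML_TOKEN}parameter name="city" string="true">Tokyo</{DSML_TOKEN}parameter>\n'
        f"</{DSML_TOKEN}invoke>\n"
        f"{TOOL_CALLS_END}"
    )
    calls = parse_dsml_output(block)
    assert calls is not None and calls[0].name == "get_weather"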
@@ -232,11 +232,11 @@ def shard_and_load(

     # Estimate timeout based on model size (5x default for large queued workloads)
     base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
-    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
+    model_size = get_weights_size(shard_metadata)
-    timeout_seconds = base_timeout + model_size_gb
+    timeout_seconds = base_timeout + model_size.in_gb
     logger.info(
         f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
-        f"(model size: {model_size_gb:.1f}GB)"
+        f"(model size: {model_size.in_gb:.1f}GB)"
     )

     match shard_metadata:
@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
     return patched if n > 0 else None


+def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
+    if "deepseek-v3.2" not in task_params.model.lower():
+        return False
+    # Use DSML encoding when tools are provided or tool results are in the conversation
+    if task_params.tools:
+        return True
+    if task_params.chat_template_messages:
+        return any(
+            msg.get("role") == "tool" for msg in task_params.chat_template_messages
+        )
+    return False
+
+
 def apply_chat_template(
     tokenizer: TokenizerWrapper,
     task_params: TextGenerationTaskParams,
@@ -469,7 +482,6 @@ def apply_chat_template(

     When chat_template_messages is available (from Chat Completions API),
     uses those directly to preserve tool_calls, thinking, and other fields.
-    Otherwise builds messages from the task params input/instructions.
     """
     formatted_messages: list[dict[str, Any]] = []
     if task_params.chat_template_messages is not None:
@@ -497,6 +509,19 @@ def apply_chat_template(
         partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
         formatted_messages = formatted_messages[:-1]

+    if _needs_dsml_encoding(task_params):
+        from exo.worker.engines.mlx.dsml_encoding import encode_messages
+
+        prompt = encode_messages(
+            messages=formatted_messages,
+            thinking_mode="thinking" if task_params.enable_thinking else "chat",
+            tools=task_params.tools,
+        )
+        if partial_assistant_content:
+            prompt += partial_assistant_content
+        logger.info(prompt)
+        return prompt
+
     extra_kwargs: dict[str, Any] = {}
     if task_params.enable_thinking is not None:
         # Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
@@ -617,18 +642,17 @@ def set_wired_limit_for_model(model_size: Memory):
     if not mx.metal.is_available():
         return

-    model_bytes = model_size.in_bytes
-    max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
-    if model_bytes > 0.9 * max_rec_size:
-        model_mb = model_bytes // 2**20
-        max_rec_mb = max_rec_size // 2**20
+    max_rec_size = Memory.from_bytes(
+        int(mx.metal.device_info()["max_recommended_working_set_size"])
+    )
+    if model_size > 0.9 * max_rec_size:
         logger.warning(
-            f"Generating with a model that requires {model_mb} MB "
-            f"which is close to the maximum recommended size of {max_rec_mb} "
+            f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
+            f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
             "MB. This can be slow. See the documentation for possible work-arounds: "
             "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
         )
-    mx.set_wired_limit(max_rec_size)
+    mx.set_wired_limit(max_rec_size.in_bytes)
     logger.info(f"Wired limit set to {max_rec_size}.")
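For reference, a minimal sketch of the Memory conversions this refactor leans on, using only the accessors visible in the hunks above (the module path of Memory is not shown in this diff and is assumed):

    # Sketch only: the accessors below (from_bytes, in_bytes, in_gb, in_float_mb)
    # are exactly the ones the updated code calls.
    total = Memory.from_bytes(64 * 1024**3)
    total.in_bytes      # raw byte count, as passed to mx.set_wired_limit above
    total.in_gb         # gigabyte view, as used for the eviction thresholds earlier
    total.in_float_mb   # megabyte view, as used in the wired-limit warning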
@@ -4,9 +4,10 @@ import resource
 import time
 from collections.abc import Generator
 from functools import cache
-from typing import Literal
+from typing import TYPE_CHECKING, Literal

 import mlx.core as mx
+from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
 from exo.shared.models.model_cards import ModelId, ModelTask
 from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
 from exo.shared.types.api import ImageGenerationStats
-from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
+from exo.shared.types.chunks import (
+    ErrorChunk,
+    ImageChunk,
+    PrefillProgressChunk,
+    TokenChunk,
+    ToolCallChunk,
+)
 from exo.shared.types.common import CommandId
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
-    PrefillProgress,
     RunnerStatusUpdated,
     TaskAcknowledged,
     TaskStatusUpdated,
@@ -315,11 +321,13 @@ def main(
     ) -> None:
         if device_rank == 0:
             event_sender.send(
-                PrefillProgress(
+                ChunkGenerated(
                     command_id=command_id,
-                    model=shard_metadata.model_card.model_id,
-                    processed_tokens=processed,
-                    total_tokens=total,
+                    chunk=PrefillProgressChunk(
+                        model=shard_metadata.model_card.model_id,
+                        processed_tokens=processed,
+                        total_tokens=total,
+                    ),
                 )
             )
         cancelled_tasks.update(cancel_receiver.collect())
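A construction sketch for the reworked event (values come from the hunk above; it only illustrates that prefill progress now rides inside ChunkGenerated, the same envelope used for token and tool-call chunks, instead of a standalone PrefillProgress event):

    # Sketch: same field names as above, placeholder token counts.
    event_sender.send(
        ChunkGenerated(
            command_id=command_id,
            chunk=PrefillProgressChunk(
                model=shard_metadata.model_card.model_id,
                processed_tokens=128,
                total_tokens=4096,
            ),
        )
    )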
@@ -346,16 +354,22 @@ def main(
                 group=group,
             )

-            # For other thinking models (GLM, etc.), check if we need to
-            # prepend the thinking tag that was consumed by the chat template
-            if detect_thinking_prompt_suffix(prompt, tokenizer):
-                mlx_generator = parse_thinking_models(
-                    mlx_generator, tokenizer
+            if tokenizer.has_thinking:
+                mlx_generator = parse_thinking_models(
+                    mlx_generator,
+                    tokenizer,
+                    # For other thinking models (GLM, etc.), check if we need to
+                    # prepend the thinking tag that was consumed by the chat template
+                    starts_in_thinking=detect_thinking_prompt_suffix(
+                        prompt, tokenizer
+                    ),
                 )

-            # GPT-OSS specific parsing to match other model formats.
+            # Model-specific output parsing for tool calls.
             if isinstance(inference_model, GptOssModel):
                 mlx_generator = parse_gpt_oss(mlx_generator)
+            elif isinstance(inference_model, DeepseekV32Model):
+                mlx_generator = parse_deepseek_v32(mlx_generator)
             elif tool_parser:
                 mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
@@ -407,6 +421,7 @@ def main(
                             stats=response.stats,
                             logprob=response.logprob,
                             top_logprobs=response.top_logprobs,
+                            is_thinking=response.is_thinking,
                         ),
                     )
                 )
@@ -573,6 +588,13 @@ def main(
                 case Shutdown():
                     current_status = RunnerShuttingDown()
                     logger.info("runner shutting down")
+                    if not TYPE_CHECKING:
+                        del inference_model, image_model, tokenizer, group
+                    mx.clear_cache()
+                    import gc
+
+                    gc.collect()
+
                     event_sender.send(
                         RunnerStatusUpdated(
                             runner_id=runner_id, runner_status=current_status
@@ -597,12 +619,8 @@ def main(
         event_sender.send(
             RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
         )
         if isinstance(current_status, RunnerShutdown):
-            del inference_model, image_model, tokenizer, group
-            mx.clear_cache()
-            import gc
-
-            gc.collect()
             break
@@ -668,44 +686,208 @@ def parse_gpt_oss(

         if ch == "analysis" and not thinking:
             thinking = True
-            yield response.model_copy(update={"text": "<think>"})

         if ch != "analysis" and thinking:
             thinking = False
-            yield response.model_copy(update={"text": "</think>"})

         if delta:
-            yield response.model_copy(update={"text": delta})
+            yield response.model_copy(update={"text": delta, "is_thinking": thinking})

         if response.finish_reason is not None:
-            if thinking:
-                yield response.model_copy(update={"text": "</think>"})
             yield response


+def parse_deepseek_v32(
+    responses: Generator[GenerationResponse],
+) -> Generator[GenerationResponse | ToolCallResponse]:
+    """Parse DeepSeek V3.2 DSML tool calls from the generation stream.
+
+    Uses accumulated-text matching (not per-token marker checks) because
+    DSML markers like <|DSML|function_calls> may span multiple tokens.
+    Also handles <think>...</think> blocks for thinking mode.
+    """
+    from exo.worker.engines.mlx.dsml_encoding import (
+        THINKING_END,
+        THINKING_START,
+        TOOL_CALLS_END,
+        TOOL_CALLS_START,
+        parse_dsml_output,
+    )
+
+    accumulated = ""
+    in_tool_call = False
+    thinking = False
+    # Tokens buffered while we detect the start of a DSML block
+    pending_buffer: list[GenerationResponse] = []
+    # Text accumulated during a tool call block
+    tool_call_text = ""
+
+    for response in responses:
+        assert isinstance(response, GenerationResponse)
+
+        # ── Handle thinking tags ──
+        if not thinking and THINKING_START in response.text:
+            thinking = True
+            # Yield any text before the <think> tag
+            before = response.text[: response.text.index(THINKING_START)]
+            if before:
+                yield response.model_copy(update={"text": before})
+            continue
+
+        if thinking and THINKING_END in response.text:
+            thinking = False
+            # Yield any text after the </think> tag
+            after = response.text[
+                response.text.index(THINKING_END) + len(THINKING_END) :
+            ]
+            if after:
+                yield response.model_copy(update={"text": after, "is_thinking": False})
+            continue
+
+        if thinking:
+            yield response.model_copy(update={"is_thinking": True})
+            continue
+
+        # ── Handle tool call accumulation ──
+        if in_tool_call:
+            tool_call_text += response.text
+            if TOOL_CALLS_END in tool_call_text:
+                # Parse the accumulated DSML block
+                parsed = parse_dsml_output(tool_call_text)
+                if parsed is not None:
+                    logger.info(f"parsed DSML tool calls: {parsed}")
+                    yield ToolCallResponse(
+                        tool_calls=parsed,
+                        usage=response.usage,
+                        stats=response.stats,
+                    )
+                else:
+                    logger.warning(
+                        f"DSML tool call parsing failed for: {tool_call_text}"
+                    )
+                    yield response.model_copy(update={"text": tool_call_text})
+                in_tool_call = False
+                tool_call_text = ""
+                continue
+
+            # EOS reached before end marker — yield buffered text as-is
+            if response.finish_reason is not None:
+                logger.info("DSML tool call parsing interrupted by EOS")
+                yield response.model_copy(update={"text": tool_call_text})
+                in_tool_call = False
+                tool_call_text = ""
+            continue
+
+        # ── Detect start of tool call block ──
+        accumulated += response.text
+
+        if TOOL_CALLS_START in accumulated:
+            # The start marker might be split across pending_buffer + current token
+            start_idx = accumulated.index(TOOL_CALLS_START)
+            # Yield any pending tokens that are purely before the marker
+            pre_text = accumulated[:start_idx]
+            if pre_text:
+                # Flush pending buffer tokens that contributed text before the marker
+                for buf_resp in pending_buffer:
+                    if pre_text:
+                        chunk = buf_resp.text
+                        if len(chunk) <= len(pre_text):
+                            yield buf_resp
+                            pre_text = pre_text[len(chunk) :]
+                        else:
+                            yield buf_resp.model_copy(update={"text": pre_text})
+                            pre_text = ""
+            pending_buffer = []
+            tool_call_text = accumulated[start_idx:]
+            accumulated = ""
+
+            # Check if the end marker is already present (entire tool call in one token)
+            if TOOL_CALLS_END in tool_call_text:
+                parsed = parse_dsml_output(tool_call_text)
+                if parsed is not None:
+                    logger.info(f"parsed DSML tool calls: {parsed}")
+                    yield ToolCallResponse(
+                        tool_calls=parsed,
+                        usage=response.usage,
+                        stats=response.stats,
+                    )
+                else:
+                    logger.warning(
+                        f"DSML tool call parsing failed for: {tool_call_text}"
+                    )
+                    yield response.model_copy(update={"text": tool_call_text})
+                tool_call_text = ""
+            else:
+                in_tool_call = True
+            continue
+
+        # Check if accumulated text might be the start of a DSML marker
+        # Buffer tokens if we see a partial match at the end
+        if _could_be_dsml_prefix(accumulated):
+            pending_buffer.append(response)
+            continue
+
+        # No partial match — flush all pending tokens and the current one
+        for buf_resp in pending_buffer:
+            yield buf_resp
+        pending_buffer = []
+        accumulated = ""
+        yield response
+
+    # Flush any remaining pending buffer at generator end
+    for buf_resp in pending_buffer:
+        yield buf_resp
+
+
+def _could_be_dsml_prefix(text: str) -> bool:
+    """Check if the end of text could be the start of a DSML function_calls marker.
+
+    We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
+    This allows us to buffer tokens until we can determine if a tool call is starting.
+    """
+    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
+
+    # Only check the last portion of text that could overlap with the marker
+    max_check = len(TOOL_CALLS_START)
+    tail = text[-max_check:] if len(text) > max_check else text
+
+    # Check if any suffix of tail is a prefix of TOOL_CALLS_START
+    for i in range(len(tail)):
+        suffix = tail[i:]
+        if TOOL_CALLS_START.startswith(suffix):
+            return True
+    return False
+
+
 def parse_thinking_models(
     responses: Generator[GenerationResponse],
     tokenizer: TokenizerWrapper,
+    starts_in_thinking: bool = True,
 ) -> Generator[GenerationResponse]:
-    """
-    For models that inject thinking tags in the prompt (like GLM-4.7),
-    prepend the thinking tag to the output stream so the frontend
-    can properly parse thinking content.
-    """
-    first = True
+    """Route thinking tokens via is_thinking flag.
+
+    Swallows think tag tokens, sets is_thinking on all others.
+    Always yields tokens with finish_reason to avoid hanging the chunk stream.
+    """
+    in_thinking = starts_in_thinking
+
     for response in responses:
         if isinstance(response, ToolCallResponse):
             yield response
             continue
-        if first:
-            first = False
-            yield response.model_copy(
-                update={
-                    "text": tokenizer.think_start,
-                    "token": tokenizer.think_start_id,
-                }
-            )
-        yield response
+
+        is_think_tag = (
+            tokenizer.think_end is not None and response.text == tokenizer.think_end
+        ) or (
+            tokenizer.think_start is not None and response.text == tokenizer.think_start
+        )
+
+        if is_think_tag:
+            in_thinking = response.text != tokenizer.think_end
+            # Never swallow finish_reason — the chunk stream needs it to terminate.
+            if response.finish_reason is not None:
+                yield response.model_copy(update={"text": "", "is_thinking": False})
+            continue
+        yield response.model_copy(update={"is_thinking": in_thinking})


 def _send_image_chunk(
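A driving sketch for the new stream parser (not part of the diff; the DSML constants come from the new dsml_encoding module above, and GenerationResponse is constructed the same way the e2e tests below do):

    from collections.abc import Generator

    from exo.worker.engines.mlx.dsml_encoding import DSML_TOKEN, TOOL_CALLS_END, TOOL_CALLS_START

    def _fake_stream() -> Generator[GenerationResponse]:
        # One plain token, then a complete DSML block arriving as a single token.
        yield GenerationResponse(text="Checking.", token=0, usage=None)
        yield GenerationResponse(
            text=(
                f'{TOOL_CALLS_START}\n<{DSML_TOKEN}invoke name="get_weather">\n'
                f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n'
                f"</{DSML_TOKEN}invoke>\n{TOOL_CALLS_END}"
            ),
            token=1,
            finish_reason="stop",
            usage=None,
        )

    out = list(parse_deepseek_v32(_fake_stream()))
    # The first element passes through as plain text; the second becomes a
    # ToolCallResponse carrying get_weather({"city": "NYC"}).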
@@ -100,8 +100,8 @@ class RunnerSupervisor:
         logger.info("Runner supervisor shutting down")
         self._ev_recv.close()
         self._task_sender.close()
-        self._event_sender.close()
-        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+        with contextlib.suppress(ClosedResourceError):
+            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
         self._cancel_sender.close()
         self.runner_process.join(5)
         if not self.runner_process.is_alive():
@@ -180,6 +180,7 @@ class RunnerSupervisor:
             await self._check_runner(e)
         for tid in self.pending:
             self.pending[tid].set()
+        self._event_sender.close()

     def __del__(self) -> None:
         if self.runner_process.is_alive():
@@ -208,10 +209,15 @@ class RunnerSupervisor:

         logger.opt(exception=e).error(f"Runner terminated ({cause})")

-        await self._event_sender.send(
-            RunnerStatusUpdated(
-                runner_id=self.bound_instance.bound_runner_id,
-                runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
-            )
-        )
+        try:
+            await self._event_sender.send(
+                RunnerStatusUpdated(
+                    runner_id=self.bound_instance.bound_runner_id,
+                    runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
+                )
+            )
+        except (ClosedResourceError, BrokenResourceError):
+            logger.warning(
+                "Event sender already closed, unable to report runner failure"
+            )
         self.shutdown()
@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

     global_download_status = {
         NODE_A: [
-            DownloadCompleted(
-                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
         ],
         NODE_B: [
-            DownloadCompleted(
-                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
         ],
     }

@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
     # Global state shows shard is downloaded for NODE_A
     global_download_status: dict[NodeId, list[DownloadProgress]] = {
         NODE_A: [
-            DownloadCompleted(
-                shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
         ],
         NODE_B: [],
     }
@@ -187,9 +181,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

     global_download_status = {
         NODE_A: [
-            DownloadCompleted(
-                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
         ],
         NODE_B: [],  # NODE_B has no downloads completed yet
     }
@@ -207,14 +199,10 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

     global_download_status = {
         NODE_A: [
-            DownloadCompleted(
-                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
         ],
         NODE_B: [
-            DownloadCompleted(
-                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
-            )
+            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
         ],  # NODE_B has no downloads completed yet
     }

967  src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py  Normal file
@@ -0,0 +1,967 @@
import json
from collections.abc import Generator
from typing import Any

from exo.shared.types.worker.runner_response import (
    GenerationResponse,
    ToolCallResponse,
)
from exo.worker.engines.mlx.dsml_encoding import (
    ASSISTANT_TOKEN,
    BOS_TOKEN,
    DSML_TOKEN,
    EOS_TOKEN,
    THINKING_END,
    THINKING_START,
    TOOL_CALLS_END,
    TOOL_CALLS_START,
    USER_TOKEN,
    encode_messages,
    parse_dsml_output,
)
from exo.worker.runner.runner import parse_deepseek_v32

# ── Shared fixtures ──────────────────────────────────────────────

_WEATHER_TOOLS: list[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "The city name"},
                    "units": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "Temperature units",
                    },
                },
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Get the current time in a timezone",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {"type": "string"},
                },
                "required": ["timezone"],
            },
        },
    },
]


def _simulate_tokens(
    texts: list[str],
    finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
    """Simulate a model producing tokens from a list of text strings."""
    for i, text in enumerate(texts):
        is_last = i == len(texts) - 1
        yield GenerationResponse(
            text=text,
            token=i,
            finish_reason="stop" if (is_last and finish_on_last) else None,
            usage=None,
        )


# ── Test: Standard text response (no tool calls) ────────────────


class TestE2EStandardResponse:
    """Model generates a plain text response — no tool calling involved."""

    def test_plain_text_passthrough(self):
        """Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
        # Step 1: Encode the prompt (with tools available)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather in NYC?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Verify prompt structure
        assert BOS_TOKEN in prompt
        assert "## Tools" in prompt
        assert "get_weather" in prompt
        assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt

        # Step 2: Simulate model response — plain text tokens (no DSML)
        model_tokens = [
            "The weather",
            " in NYC",
            " is 72",
            "°F",
            " and sunny",
            ".",
        ]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        # Step 3: Verify all tokens pass through as GenerationResponse
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        assert len(gen_results) == 6
        full_text = "".join(r.text for r in gen_results)
        assert full_text == "The weather in NYC is 72°F and sunny."
        assert gen_results[-1].finish_reason == "stop"


# ── Test: Tool call response ─────────────────────────────────────


class TestE2EToolCallResponse:
    """Model generates a DSML tool call — realistic token boundaries."""

    def test_realistic_tool_call_tokens(self):
        """Simulate model generating a get_weather tool call with realistic token splits.

        Real models split DSML markers across tokens unpredictably.
        This simulates how DeepSeek V3.2 actually tokenizes DSML output.
        """
        # Step 1: Encode prompt
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather in San Francisco?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
        assert "get_weather" in prompt

        # Step 2: Simulate realistic token-by-token model output
        # The model first produces some text, then a DSML tool call block
        model_tokens = [
            "I'll check the weather for you.",
            "\n\n",
            f"<{DSML_TOKEN}",  # marker split across tokens
            "function_calls>\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">',
            "San Francisco",
            f"</{DSML_TOKEN}parameter>\n",
            f'<{DSML_TOKEN}parameter name="units" string="false">',
            '"celsius"',
            f"</{DSML_TOKEN}parameter>\n",
            f"</{DSML_TOKEN}invoke>\n",
            f"</{DSML_TOKEN}function_calls>",
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        # Step 3: Verify
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        # Should have text tokens before tool call + one ToolCallResponse
        assert len(tool_results) == 1
        assert len(tool_results[0].tool_calls) == 1

        tc = tool_results[0].tool_calls[0]
        assert tc.name == "get_weather"
        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "San Francisco"
        assert args["units"] == "celsius"

        # The text before the tool call should still be yielded
        text_before = "".join(r.text for r in gen_results if not r.is_thinking)
        assert "check the weather" in text_before

    def test_multiple_tool_calls_in_one_block(self):
        """Model generates two tool calls in a single function_calls block."""
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Weather in NYC and time in EST?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
        assert "get_weather" in prompt
        assert "get_time" in prompt

        # Simulate model output with two invocations
        model_tokens = [
            "Let me check both.\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        assert len(tool_results[0].tool_calls) == 2
        assert tool_results[0].tool_calls[0].name == "get_weather"
        assert tool_results[0].tool_calls[1].name == "get_time"

        args0 = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        args1 = json.loads(tool_results[0].tool_calls[1].arguments)  # pyright: ignore[reportAny]
        assert args0 == {"city": "NYC"}
        assert args1 == {"timezone": "EST"}


# ── Test: Multi-turn tool use flow ───────────────────────────────


class TestE2EMultiTurnToolUse:
    """Full multi-turn: user asks → model calls tool → tool result → model answers."""

    def test_encode_multi_turn_with_tool_results(self):
        """Verify the prompt for turn 2 (after tool results) is correctly encoded."""
        # Turn 1: user asks, model calls tool
        # Turn 2: tool result provided, model answers
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a weather assistant."},
            {"role": "user", "content": "What's the weather in NYC?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "NYC"}',
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Verify multi-turn structure
        assert BOS_TOKEN in prompt
        assert "You are a weather assistant." in prompt
        assert "## Tools" in prompt

        # The assistant's tool call should be encoded as DSML
        assert TOOL_CALLS_START in prompt
        assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
        assert EOS_TOKEN in prompt

        # The tool result should be wrapped in function_results
        assert "<function_results>" in prompt
        assert "<result>" in prompt
        assert "72" in prompt
        assert "</function_results>" in prompt

        # Now simulate model answering after seeing the tool result
        model_tokens = [
            "The current",
            " weather in NYC",
            " is 72°F",
            " and sunny.",
        ]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        full_text = "".join(r.text for r in gen_results)
        assert full_text == "The current weather in NYC is 72°F and sunny."

    def test_multi_tool_results_encoding(self):
        """Verify encoding when model called two tools and both return results."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Weather and time?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "LA"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "PST"}',
                        },
                    },
                ],
            },
            {"role": "tool", "content": "85F, clear skies"},
            {"role": "tool", "content": "3:42 PM PST"},
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Should have one function_results block with two results
        assert prompt.count("<function_results>") == 1
        assert prompt.count("</function_results>") == 1
        assert "<result>85F, clear skies</result>" in prompt
        assert "<result>3:42 PM PST</result>" in prompt


# ── Test: Thinking + tool call ───────────────────────────────────


class TestE2EThinkingAndToolCall:
    """Model uses thinking mode, reasons, then makes a tool call."""

    def test_thinking_then_tool_call(self):
        """Model thinks first, then produces a DSML tool call block."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "What's the weather?"},
        ]
        prompt = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        # Thinking mode: prompt should end with <think>
        assert prompt.endswith(THINKING_START)

        # Simulate: model outputs <think>, thinks, closes thinking, then tool call.
        # In the full pipeline, parse_thinking_models handles the case where
        # <think> is in the prompt. Here we test parse_deepseek_v32 directly,
        # which detects <think>/<think> markers in the stream.
        model_tokens = [
            THINKING_START,
            "The user wants weather",
            " information. I should use",
            " the get_weather tool.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">',
            "San Francisco",
            f"</{DSML_TOKEN}parameter>\n",
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        # Should have thinking tokens + tool call
        thinking_results = [r for r in gen_results if r.is_thinking]

        assert len(thinking_results) >= 1
        thinking_text = "".join(r.text for r in thinking_results)
        assert "get_weather tool" in thinking_text

        assert len(tool_results) == 1
        assert tool_results[0].tool_calls[0].name == "get_weather"
        args = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "San Francisco"

    def test_thinking_prompt_encoding(self):
        """Verify thinking mode affects prompt encoding correctly."""
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "Be thorough."},
            {"role": "user", "content": "What's the weather?"},
        ]

        # With thinking enabled
        prompt_think = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_think.endswith(THINKING_START)

        # With thinking disabled
        prompt_no_think = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
        )
        assert prompt_no_think.endswith(THINKING_END)

        # Both should have the same tool definitions
        assert "get_weather" in prompt_think
        assert "get_weather" in prompt_no_think


# ── Test: Round-trip encode → parse ──────────────────────────────


class TestE2ERoundTrip:
    """Verify that DSML we encode can be parsed back correctly."""

    def test_encoded_tool_call_is_parseable(self):
        """Encode an assistant tool call message, then parse the DSML output."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Weather?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Tokyo", "units": "celsius"}',
                        },
                    }
                ],
            },
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Extract the DSML function_calls block from the prompt
        start = prompt.index(TOOL_CALLS_START)
        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
        dsml_block = prompt[start:end]

        # Parse it back
        parsed = parse_dsml_output(dsml_block)
        assert parsed is not None
        assert len(parsed) == 1
        assert parsed[0].name == "get_weather"
        args = json.loads(parsed[0].arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "Tokyo"
        assert args["units"] == "celsius"

    def test_encoded_multi_tool_call_round_trips(self):
        """Encode multiple tool calls, verify they parse back correctly."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Both please"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Paris"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "CET"}',
                        },
                    },
                ],
            },
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        start = prompt.index(TOOL_CALLS_START)
        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
        dsml_block = prompt[start:end]

        parsed = parse_dsml_output(dsml_block)
        assert parsed is not None
        assert len(parsed) == 2
        assert parsed[0].name == "get_weather"
        assert parsed[1].name == "get_time"
        assert json.loads(parsed[0].arguments) == {"city": "Paris"}
        assert json.loads(parsed[1].arguments) == {"timezone": "CET"}


# ── Test: Edge cases with realistic token boundaries ─────────────


class TestE2EEdgeCases:
    """Edge cases that occur in real model inference."""

    def test_dsml_marker_split_at_fullwidth_pipe(self):
        """The fullwidth pipe character | might be its own token."""
        # This is a realistic tokenization: the DSML marker is split at the | chars
        model_tokens = [
            "Let me help.\n\n",
            "<\uff5c",  # start of |DSML|
            "DSML\uff5c",  # rest of DSML token
            "function_calls>\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        assert tool_results[0].tool_calls[0].name == "get_weather"

    def test_tool_call_with_nested_json_object(self):
        """Model passes a complex JSON object as a non-string parameter."""
        dsml_block = (
            f"{TOOL_CALLS_START}\n"
            f'<{DSML_TOKEN}invoke name="create_event">\n'
            f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
            f'<{DSML_TOKEN}parameter name="config" string="false">'
            f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
            f"</{DSML_TOKEN}parameter>\n"
            f"</{DSML_TOKEN}invoke>\n"
            f"{TOOL_CALLS_END}"
        )

        # Feed as single token (model might produce it all at once after prefill)
        results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        tc = tool_results[0].tool_calls[0]
        assert tc.name == "create_event"
        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]
        assert args["title"] == "Team Standup"
        assert args["config"]["recurring"] is True
        assert args["config"]["days"] == ["mon", "wed", "fri"]

    def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
        """Angle brackets in normal text should not trigger DSML buffering."""
        model_tokens = [
            "The formula is ",
            "<x, y>",
            " where x > 0",
            " and y < 100.",
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        full_text = "".join(r.text for r in gen_results)
        assert "formula" in full_text
        assert "<x, y>" in full_text

    def test_empty_model_response(self):
        """Model produces only EOS (empty response)."""
        model_tokens = [""]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        assert len(gen_results) == 1
        assert gen_results[0].text == ""
        assert gen_results[0].finish_reason == "stop"


# ── Test: Full EPDP spec round-trip ──────────────────────────────


class TestE2EFullRoundTrip:
    """Full round-trip matching the vLLM EPDP spec.

    Simulates the complete multi-turn flow:
    Turn 1: user asks → think → tool call → tool result → think → answer
    Turn 2: user asks again → old reasoning stripped → think → answer
    """

    def test_single_tool_full_flow_with_thinking(self):
        """Complete flow: user → think → tool call → tool result → think → answer.

        This is the core EPDP flow from the vLLM spec.
        """
        # ── Turn 1.1: User asks, encode prompt ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a weather assistant."},
            {"role": "user", "content": "How's the weather in Hangzhou?"},
        ]
        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)
        assert "## Tools" in prompt_1
        assert "get_weather" in prompt_1

        # ── Turn 1.1: Model thinks, then calls tool ──
        model_tokens_1 = [
            THINKING_START,
            "The user wants to know the weather in Hangzhou.",
            " I need to use the get_weather tool.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))

        # Verify: thinking tokens + tool call
        gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
        thinking_1 = [r for r in gen_1 if r.is_thinking]

        assert len(thinking_1) >= 1
        assert "get_weather tool" in "".join(r.text for r in thinking_1)
        assert len(tool_1) == 1
        assert tool_1[0].tool_calls[0].name == "get_weather"
        tc_args = json.loads(tool_1[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        assert tc_args == {"city": "Hangzhou"}

        # ── Turn 1.2: Add assistant response + tool result to messages ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Hangzhou"}',
                        },
                    }
                ],
            }
        )
        messages.append(
            {
                "role": "tool",
                "content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
            }
        )

        # Encode prompt for turn 1.2
        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Verify: prompt has the full conversation structure
        assert TOOL_CALLS_START in prompt_2  # assistant's encoded tool call
        assert EOS_TOKEN in prompt_2  # assistant turn ends with EOS
        assert "<function_results>" in prompt_2
        assert "<result>" in prompt_2
        assert "Cloudy" in prompt_2
        assert "</function_results>" in prompt_2
        # After tool results with thinking enabled → <think> appended
        assert prompt_2.endswith(THINKING_START)
        # The assistant's reasoning_content should appear (it's after last_user_idx)
        assert "get_weather tool" in prompt_2

        # ── Turn 1.2: Model thinks about results, then answers ──
        model_tokens_2 = [
            THINKING_START,
            "The weather in Hangzhou is Cloudy, 7~13°C.",
            " I'll tell the user.",
            THINKING_END,
            "The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))

        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        thinking_2 = [r for r in gen_2 if r.is_thinking]
        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]

        assert len(tool_2) == 0  # No more tool calls
        assert len(thinking_2) >= 1
        assert "Cloudy" in "".join(r.text for r in thinking_2)
        assert len(non_thinking_2) >= 1
        final_text = "".join(r.text for r in non_thinking_2)
        assert "7°C" in final_text
        assert "13°C" in final_text

    def test_multi_tool_full_flow(self):
        """Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
        # ── Initial prompt ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You help with weather and time."},
            {"role": "user", "content": "Weather in NYC and time in EST?"},
        ]
        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)

        # ── Model thinks, calls both tools ──
        model_tokens_1 = [
            THINKING_START,
            "Two requests: weather and time. I'll call both.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]

        assert len(tool_1) == 1
        assert len(tool_1[0].tool_calls) == 2
        assert tool_1[0].tool_calls[0].name == "get_weather"
        assert tool_1[0].tool_calls[1].name == "get_time"
# ── Add assistant + both tool results ──
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"reasoning_content": "Two requests: weather and time. I'll call both.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "NYC"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_time",
|
||||||
|
"arguments": '{"timezone": "EST"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append({"role": "tool", "content": "72°F, sunny"})
|
||||||
|
messages.append({"role": "tool", "content": "2:30 PM EST"})
|
||||||
|
|
||||||
|
prompt_2 = encode_messages(
|
||||||
|
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify multi-tool result encoding
|
||||||
|
# Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
|
||||||
|
assert prompt_2.count("<function_results>") == 2
|
||||||
|
assert prompt_2.count("</function_results>") == 2
|
||||||
|
assert "<result>72°F, sunny</result>" in prompt_2
|
||||||
|
assert "<result>2:30 PM EST</result>" in prompt_2
|
||||||
|
assert prompt_2.endswith(THINKING_START)
|
||||||
|
|
||||||
|
# ── Model thinks about results, answers ──
|
||||||
|
model_tokens_2 = [
|
||||||
|
THINKING_START,
|
||||||
|
"Got both results. Weather is 72°F sunny, time is 2:30 PM.",
|
||||||
|
THINKING_END,
|
||||||
|
"In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
|
||||||
|
]
|
||||||
|
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||||
|
|
||||||
|
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||||
|
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
|
||||||
|
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
|
||||||
|
|
||||||
|
assert len(tool_2) == 0
|
||||||
|
final_text = "".join(r.text for r in non_thinking_2)
|
||||||
|
assert "72°F" in final_text
|
||||||
|
assert "2:30 PM" in final_text
|
||||||
|
|
||||||
|
def test_two_user_turns_reasoning_stripped(self):
|
||||||
|
"""Turn 2: old reasoning_content is stripped from history.
|
||||||
|
|
||||||
|
Per the vLLM spec, clear_reasoning_content is called between user turns
|
||||||
|
to save bandwidth. Our _drop_old_thinking handles this.
|
||||||
|
"""
|
||||||
|
# Full turn 1 conversation (already completed)
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Weather in Hangzhou?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"reasoning_content": "I need to call get_weather for Hangzhou.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "Hangzhou"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "content": "Cloudy 7~13°C"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
|
||||||
|
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
|
||||||
|
},
|
||||||
|
# Turn 2: user asks again
|
||||||
|
{"role": "user", "content": "What about Beijing?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt = encode_messages(
|
||||||
|
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Old reasoning_content from turn 1 assistants should be STRIPPED
|
||||||
|
# (they're before the last user message at index 5)
|
||||||
|
assert "I need to call get_weather" not in prompt
|
||||||
|
assert "tool returned cloudy" not in prompt
|
||||||
|
|
||||||
|
# But the assistant's content and tool calls should still be there
|
||||||
|
assert "cloudy, 7-13°C" in prompt
|
||||||
|
assert TOOL_CALLS_START in prompt
|
||||||
|
|
||||||
|
# Prompt ends with <think> for the new turn
|
||||||
|
assert prompt.endswith(THINKING_START)
|
||||||
|
|
||||||
|
# ── Turn 2: Model thinks, calls tool for Beijing ──
|
||||||
|
model_tokens = [
|
||||||
|
THINKING_START,
|
||||||
|
"Now the user wants Beijing weather.",
|
||||||
|
THINKING_END,
|
||||||
|
"\n\n",
|
||||||
|
TOOL_CALLS_START,
|
||||||
|
"\n",
|
||||||
|
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||||
|
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
|
||||||
|
f"</{DSML_TOKEN}invoke>\n",
|
||||||
|
TOOL_CALLS_END,
|
||||||
|
]
|
||||||
|
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||||
|
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||||
|
|
||||||
|
assert len(tool_results) == 1
|
||||||
|
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||||
|
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||||
|
assert args == {"city": "Beijing"}
|
||||||
|
|
||||||
|
def test_chained_tool_calls_loop(self):
|
||||||
|
"""Model calls tool, gets result, calls another tool, gets result, answers.
|
||||||
|
|
||||||
|
This simulates the inner while loop from the vLLM spec where the model
|
||||||
|
may need multiple sub-turns of tool calling before it has enough info.
|
||||||
|
"""
|
||||||
|
# ── Sub-turn 1: user asks, model calls get_time ──
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt_1 = encode_messages(
|
||||||
|
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||||
|
)
|
||||||
|
assert prompt_1.endswith(THINKING_START)
|
||||||
|
|
||||||
|
# Model first calls get_time to figure out the date
|
||||||
|
model_tokens_1 = [
|
||||||
|
THINKING_START,
|
||||||
|
"I need the current date first to calculate tomorrow.",
|
||||||
|
THINKING_END,
|
||||||
|
"\n\n",
|
||||||
|
TOOL_CALLS_START,
|
||||||
|
"\n",
|
||||||
|
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||||
|
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
|
||||||
|
f"</{DSML_TOKEN}invoke>\n",
|
||||||
|
TOOL_CALLS_END,
|
||||||
|
]
|
||||||
|
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||||
|
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||||
|
assert len(tool_1) == 1
|
||||||
|
assert tool_1[0].tool_calls[0].name == "get_time"
|
||||||
|
|
||||||
|
# ── Sub-turn 2: add tool result, model calls get_weather ──
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"reasoning_content": "I need the current date first to calculate tomorrow.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_time",
|
||||||
|
"arguments": '{"timezone": "Asia/Shanghai"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
|
||||||
|
|
||||||
|
prompt_2 = encode_messages(
|
||||||
|
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||||
|
)
|
||||||
|
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
|
||||||
|
assert prompt_2.endswith(THINKING_START)
|
||||||
|
|
||||||
|
# Model now knows the date, calls get_weather
|
||||||
|
model_tokens_2 = [
|
||||||
|
THINKING_START,
|
||||||
|
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
|
||||||
|
" Now I can check weather for Hangzhou.",
|
||||||
|
THINKING_END,
|
||||||
|
"\n\n",
|
||||||
|
TOOL_CALLS_START,
|
||||||
|
"\n",
|
||||||
|
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||||
|
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||||
|
f"</{DSML_TOKEN}invoke>\n",
|
||||||
|
TOOL_CALLS_END,
|
||||||
|
]
|
||||||
|
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||||
|
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||||
|
assert len(tool_2) == 1
|
||||||
|
assert tool_2[0].tool_calls[0].name == "get_weather"
|
||||||
|
|
||||||
|
# ── Sub-turn 3: add weather result, model answers ──
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "Hangzhou"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
|
||||||
|
|
||||||
|
prompt_3 = encode_messages(
|
||||||
|
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||||
|
)
|
||||||
|
# Should have both function_results blocks (one per tool round)
|
||||||
|
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
|
||||||
|
assert prompt_3.count("<function_results>") == 3
|
||||||
|
assert prompt_3.count("</function_results>") == 3
|
||||||
|
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
|
||||||
|
assert "<result>Sunny, 5~12°C</result>" in prompt_3
|
||||||
|
assert prompt_3.endswith(THINKING_START)
|
||||||
|
|
||||||
|
# Model finally answers
|
||||||
|
model_tokens_3 = [
|
||||||
|
THINKING_START,
|
||||||
|
"I have the weather for tomorrow in Hangzhou.",
|
||||||
|
THINKING_END,
|
||||||
|
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
|
||||||
|
]
|
||||||
|
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
|
||||||
|
|
||||||
|
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
|
||||||
|
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
|
||||||
|
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
|
||||||
|
|
||||||
|
assert len(tool_3) == 0 # No more tool calls — loop ends
|
||||||
|
final_text = "".join(r.text for r in non_thinking_3)
|
||||||
|
assert "sunny" in final_text.lower()
|
||||||
|
assert "5°C" in final_text
|
||||||
|
assert "12°C" in final_text
|
||||||
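For orientation, the round-trip tests above hand-simulate the outer agent loop from the vLLM EPDP spec. A minimal sketch of that loop follows; it is illustrative only and not part of this change. run_agent_turn, generate_tokens, and execute_tool are hypothetical stand-ins, while the encode_messages / parse_deepseek_v32 usage mirrors the tests above.

# Illustrative sketch only (not part of this diff): the agent loop the tests
# above emulate by hand. `generate_tokens` and `execute_tool` are hypothetical.
def run_agent_turn(messages, tools, generate_tokens, execute_tool):
    while True:
        # Re-encode the full conversation; the prompt ends with <think> in thinking mode.
        prompt = encode_messages(messages, tools=tools, thinking_mode="thinking")
        thinking, answer, tool_calls = [], [], []
        for r in parse_deepseek_v32(generate_tokens(prompt)):
            if isinstance(r, ToolCallResponse):
                tool_calls = list(r.tool_calls)
            elif isinstance(r, GenerationResponse):
                (thinking if r.is_thinking else answer).append(r.text)
        assistant = {
            "role": "assistant",
            "content": "".join(answer),
            "reasoning_content": "".join(thinking),
        }
        if not tool_calls:
            # No tool calls: the model answered, so the inner loop ends.
            messages.append(assistant)
            return assistant
        assistant["tool_calls"] = [
            {
                "type": "function",
                "function": {"name": tc.name, "arguments": tc.arguments},
            }
            for tc in tool_calls
        ]
        messages.append(assistant)
        for tc in tool_calls:
            # One tool message per call, in call order (as the tests encode it).
            messages.append({"role": "tool", "content": execute_tool(tc)})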
@@ -148,6 +148,7 @@ class MockTokenizer:
     tool_call_start = None
     tool_call_end = None
     has_tool_calling = False
+    has_thinking = False


 class MockGroup:
@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
     def test_thinking_then_tool_call(self):
         results = _collect(THINKING_THEN_TOOL_TOKENS)

-        # Should have thinking tags + content + tool call
-        text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
-        combined = "".join(text_parts)
-        assert "<think>" in combined
-        assert "</think>" in combined
-        assert "Let me think about this." in combined
+        # Thinking tokens should have is_thinking=True and no <think> tags
+        thinking_responses = [
+            r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
+        ]
+        thinking_text = "".join(r.text for r in thinking_responses)
+        assert "Let me think about this." in thinking_text
+        assert "<think>" not in thinking_text
+        assert "</think>" not in thinking_text
+
+        # Non-thinking tokens should have is_thinking=False
+        non_thinking = [
+            r
+            for r in results
+            if isinstance(r, GenerationResponse) and not r.is_thinking
+        ]
+        non_thinking_text = "".join(r.text for r in non_thinking)
+        assert "<think>" not in non_thinking_text

         # And the tool call
         tc = _get_tool_call(results)
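The updated assertions above reflect that the gpt-oss parser now reports reasoning through the is_thinking flag instead of emitting literal <think>/</think> tags in the text. A minimal consumer-side sketch, illustrative only (split_reasoning is a hypothetical helper; GenerationResponse is the type used in the tests):

# Illustrative sketch only: split a parsed stream into reasoning vs. visible text,
# which is what the updated assertions check. `split_reasoning` is hypothetical.
def split_reasoning(results):
    reasoning, content = [], []
    for r in results:
        if isinstance(r, GenerationResponse):
            (reasoning if r.is_thinking else content).append(r.text)
    return "".join(reasoning), "".join(content)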
8
tmp/config_examples/claude_code.sh
Executable file
@@ -0,0 +1,8 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude
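For a quick smoke test of the same local cluster outside Claude Code, a minimal Python request is sketched below. It assumes exo's ChatGPT-compatible endpoint at /v1/chat/completions on the same port; adjust the path and model id if your setup differs.

# Illustrative sketch only: query the local exo cluster directly.
# Assumes a ChatGPT-compatible /v1/chat/completions endpoint on port 52415.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(
        {
            "model": "mlx-community/gpt-oss-120b-MXFP4-Q8",
            "messages": [{"role": "user", "content": "Say hello from exo."}],
        }
    ).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])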