mirror of https://github.com/exo-explore/exo.git
synced 2026-02-20 07:46:42 -05:00

Compare commits: 13 commits on branch feat/bug-r (author: JakeHillio)

- e9bd90a647
- f662c129dd
- c45ff9ad43
- 7031901ae5
- cf648a53b8
- 94b2ce6922
- 423ed0f07f
- ed001f2409
- 4c4c6ce99f
- 42e1e7322b
- aa3f106fb9
- 2e29605194
- cacb456cb2
Cargo.lock (generated): 13
@@ -890,7 +890,7 @@ dependencies = [
 "delegate",
 "env_logger",
 "extend",
 "futures",
 "futures-lite",
 "libp2p",
 "log",
 "networking",
@@ -914,6 +914,12 @@ dependencies = [
 "syn 2.0.111",
]

[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"

[[package]]
name = "ff"
version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
 "fastrand",
 "futures-core",
 "futures-io",
 "parking",
 "pin-project-lite",
]

@@ -2753,7 +2762,7 @@ dependencies = [
 "delegate",
 "either",
 "extend",
 "futures",
 "futures-lite",
 "futures-timer",
 "keccak-const",
 "libp2p",
@@ -29,14 +29,13 @@ util = { path = "rust/util" }
# Macro dependecies
extend = "1.2"
delegate = "0.13"
pin-project = "1"

# Utility dependencies
keccak-const = "0.2"

# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-timer = "3.0"

# Data structures
@@ -20,6 +20,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
+   run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:

    selected.sort(key=_placement_sort_key)
    preview = selected[0]

+   settle_deadline = (
+       time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+   )
+
+   print("Planning phase: checking downloads...", file=log)
+   run_planning_phase(
+       exo,
+       full_model_id,
+       preview,
+       args.danger_delete_downloads,
+       args.timeout,
+       settle_deadline,
+   )

    instance = preview["instance"]
    instance_id = instance_id_from_instance(instance)
    sharding = str(preview["sharding"])
@@ -35,6 +35,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
+   run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
    if args.dry_run:
        return 0

+   settle_deadline = (
+       time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
+   )
+
+   logger.info("Planning phase: checking downloads...")
+   run_planning_phase(
+       client,
+       full_model_id,
+       selected[0],
+       args.danger_delete_downloads,
+       args.timeout,
+       settle_deadline,
+   )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:
bench/harness.py: 150
@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
    return selected


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")


def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for benchmark model.",
    )
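For orientation, here is a minimal sketch of how a bench script drives this planning phase end to end, mirroring the call sites added above. The `ExoClient` endpoint and model id are illustrative, and `preview` stands in for one placement dict from `settle_and_fetch_placements`:

```python
# Hedged sketch, not a shipped script. Assumes bench/harness.py exposes
# ExoClient and run_planning_phase as in this diff.
import time

from harness import ExoClient, run_planning_phase

client = ExoClient("localhost", 52415)  # hypothetical endpoint
full_model_id = "org/example-model"     # hypothetical model id
preview = ...  # one placement preview from settle_and_fetch_placements()

settle_timeout = 120.0  # 0 would disable the disk-info settling loop
settle_deadline = (
    time.monotonic() + settle_timeout if settle_timeout > 0 else None
)

# Checks free disk on every node in the placement, optionally frees space
# (danger_delete=True deletes smallest models first), starts the shard
# downloads, and blocks until they complete or `timeout` elapses.
run_planning_phase(
    client,
    full_model_id,
    preview,
    False,    # danger_delete
    600.0,    # timeout in seconds
    settle_deadline,
)
```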
@@ -250,6 +250,11 @@ interface RawStateResponse {
  >;
  // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
  thunderboltBridgeCycles?: string[][];
+ // Disk usage per node
+ nodeDisk?: Record<
+   string,
+   { total: { inBytes: number }; available: { inBytes: number } }
+ >;
}

export interface MessageAttachment {
@@ -1652,11 +1657,12 @@ class AppStore {
    if (!reader) throw new Error("No response body");

    let fullContent = prefixText;
+   let streamedThinking = "";
    const collectedTokens: TokenData[] = [...tokensToKeep];

    interface ChatCompletionChunk {
      choices?: Array<{
-       delta?: { content?: string };
+       delta?: { content?: string; reasoning_content?: string };
        logprobs?: {
          content?: Array<{
            token: string;
@@ -1677,6 +1683,7 @@ class AppStore {
      (parsed) => {
        const choice = parsed.choices?.[0];
        const delta = choice?.delta?.content;
+       const thinkingDelta = choice?.delta?.reasoning_content;

        // Collect logprobs data
        const logprobsContent = choice?.logprobs?.content;
@@ -1695,7 +1702,11 @@ class AppStore {
          }
        }

-       if (delta) {
+       if (thinkingDelta) {
+         streamedThinking += thinkingDelta;
+       }

+       if (delta || thinkingDelta) {
          if (firstTokenTime === null) {
            firstTokenTime = performance.now();
            this.ttftMs = firstTokenTime - requestStartTime;
@@ -1709,9 +1720,14 @@ class AppStore {
            this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
          }

-         fullContent += delta;
-         const { displayContent, thinkingContent } =
+         if (delta) {
+           fullContent += delta;
+         }
+         const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(fullContent);
+         const combinedThinking = [streamedThinking, tagThinking]
+           .filter(Boolean)
+           .join("\n\n");

          if (this.activeConversationId === targetConversationId) {
            this.currentResponse = displayContent;
@@ -1723,7 +1739,7 @@ class AppStore {
            messageId,
            (m) => {
              m.content = displayContent;
-             m.thinking = thinkingContent || undefined;
+             m.thinking = combinedThinking || undefined;
              m.tokens = [...collectedTokens];
            },
          );
@@ -1735,11 +1751,14 @@ class AppStore {

    // Final update
    if (this.conversationExists(targetConversationId)) {
-     const { displayContent, thinkingContent } =
+     const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(fullContent);
+     const finalThinking = [streamedThinking, tagThinking]
+       .filter(Boolean)
+       .join("\n\n");
      this.updateConversationMessage(targetConversationId, messageId, (m) => {
        m.content = displayContent;
-       m.thinking = thinkingContent || undefined;
+       m.thinking = finalThinking || undefined;
        m.tokens = [...collectedTokens];
        if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
        if (this.tps !== null) m.tps = this.tps;
@@ -1847,11 +1866,12 @@ class AppStore {
    }

    let streamedContent = "";
+   let streamedThinking = "";
    const collectedTokens: TokenData[] = [];

    interface ChatCompletionChunk {
      choices?: Array<{
-       delta?: { content?: string };
+       delta?: { content?: string; reasoning_content?: string };
        logprobs?: {
          content?: Array<{
            token: string;
@@ -1872,6 +1892,7 @@ class AppStore {
      (parsed) => {
        const choice = parsed.choices?.[0];
        const delta = choice?.delta?.content;
+       const thinkingDelta = choice?.delta?.reasoning_content;

        // Collect logprobs data
        const logprobsContent = choice?.logprobs?.content;
@@ -1890,10 +1911,19 @@ class AppStore {
          }
        }

-       if (delta) {
-         streamedContent += delta;
-         const { displayContent, thinkingContent } =
+       if (thinkingDelta) {
+         streamedThinking += thinkingDelta;
+       }

+       if (delta || thinkingDelta) {
+         if (delta) {
+           streamedContent += delta;
+         }
+         const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(streamedContent);
+         const combinedThinking = [streamedThinking, tagThinking]
+           .filter(Boolean)
+           .join("\n\n");

          // Only update currentResponse if target conversation is active
          if (this.activeConversationId === targetConversationId) {
@@ -1906,7 +1936,7 @@ class AppStore {
            assistantMessage.id,
            (msg) => {
              msg.content = displayContent;
-             msg.thinking = thinkingContent || undefined;
+             msg.thinking = combinedThinking || undefined;
              msg.tokens = [...collectedTokens];
            },
          );
@@ -1918,14 +1948,17 @@ class AppStore {

    // Final cleanup of the message (if conversation still exists)
    if (this.conversationExists(targetConversationId)) {
-     const { displayContent, thinkingContent } =
+     const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(streamedContent);
+     const finalThinking = [streamedThinking, tagThinking]
+       .filter(Boolean)
+       .join("\n\n");
      this.updateConversationMessage(
        targetConversationId,
        assistantMessage.id,
        (msg) => {
          msg.content = displayContent;
-         msg.thinking = thinkingContent || undefined;
+         msg.thinking = finalThinking || undefined;
          msg.tokens = [...collectedTokens];
        },
      );
@@ -2317,10 +2350,11 @@ class AppStore {
    }

    let streamedContent = "";
+   let streamedThinking = "";

    interface ChatCompletionChunk {
      choices?: Array<{
-       delta?: { content?: string };
+       delta?: { content?: string; reasoning_content?: string };
        logprobs?: {
          content?: Array<{
            token: string;
@@ -2348,6 +2382,7 @@ class AppStore {

        const choice = parsed.choices?.[0];
        const tokenContent = choice?.delta?.content;
+       const thinkingContent = choice?.delta?.reasoning_content;

        // Collect logprobs data
        const logprobsContent = choice?.logprobs?.content;
@@ -2366,7 +2401,11 @@ class AppStore {
          }
        }

-       if (tokenContent) {
+       if (thinkingContent) {
+         streamedThinking += thinkingContent;
+       }

+       if (tokenContent || thinkingContent) {
          // Track first token for TTFT
          if (firstTokenTime === null) {
            firstTokenTime = performance.now();
@@ -2383,11 +2422,16 @@ class AppStore {
            this.tps = (tokenCount / elapsed) * 1000;
          }

-         streamedContent += tokenContent;
+         if (tokenContent) {
+           streamedContent += tokenContent;
+         }

-         // Strip thinking tags for display and extract thinking content
-         const { displayContent, thinkingContent } =
+         // Use stripThinkingTags as fallback for any <think> tags still in content
+         const { displayContent, thinkingContent: tagThinking } =
            this.stripThinkingTags(streamedContent);
+         const combinedThinking = [streamedThinking, tagThinking]
+           .filter(Boolean)
+           .join("\n\n");

          // Only update currentResponse if target conversation is active
          if (this.activeConversationId === targetConversationId) {
@@ -2400,7 +2444,7 @@ class AppStore {
            assistantMessage.id,
            (msg) => {
              msg.content = displayContent;
-             msg.thinking = thinkingContent || undefined;
+             msg.thinking = combinedThinking || undefined;
              msg.tokens = [...collectedTokens];
            },
          );
@@ -2436,14 +2480,17 @@ class AppStore {

    // Final cleanup of the message (if conversation still exists)
    if (this.conversationExists(targetConversationId)) {
-     const { displayContent, thinkingContent } =
+     const { displayContent, thinkingContent: tagThinking } =
        this.stripThinkingTags(streamedContent);
+     const finalThinking = [streamedThinking, tagThinking]
+       .filter(Boolean)
+       .join("\n\n");
      this.updateConversationMessage(
        targetConversationId,
        assistantMessage.id,
        (msg) => {
          msg.content = displayContent;
-         msg.thinking = thinkingContent || undefined;
+         msg.thinking = finalThinking || undefined;
          msg.tokens = [...collectedTokens];
          // Store performance metrics on the message
          if (this.ttftMs !== null) {
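All three streaming paths above apply the same merge rule: accumulate `reasoning_content` deltas separately, keep `<think>`-tag stripping as a fallback on the visible content, and join the two. A hedged Python sketch of that rule over OpenAI-style chunks (the chunk shape follows the `ChatCompletionChunk` interface above; everything else is illustrative):

```python
import re

THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)

def strip_thinking_tags(text: str) -> tuple[str, str]:
    """Fallback path: split <think>...</think> spans out of the content."""
    thinking = "\n\n".join(THINK_RE.findall(text))
    display = THINK_RE.sub("", text)
    return display, thinking

def merge_stream(chunks) -> tuple[str, str]:
    """Fold chat-completion chunks into (display_content, combined_thinking)."""
    full_content, streamed_thinking = "", ""
    display, combined = "", ""
    for chunk in chunks:
        delta = (chunk.get("choices") or [{}])[0].get("delta") or {}
        if delta.get("reasoning_content"):
            streamed_thinking += delta["reasoning_content"]
        if delta.get("content"):
            full_content += delta["content"]
        if delta.get("content") or delta.get("reasoning_content"):
            display, tag_thinking = strip_thinking_tags(full_content)
            combined = "\n\n".join(
                t for t in (streamed_thinking, tag_thinking) if t
            )
    return display, combined

# e.g. a model that streams reasoning separately, then content:
chunks = [
    {"choices": [{"delta": {"reasoning_content": "consider 2+2"}}]},
    {"choices": [{"delta": {"content": "4"}}]},
]
assert merge_stream(chunks) == ("4", "consider 2+2")
```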
@@ -114,6 +114,74 @@
  });
  let tb5InfoDismissed = $state(false);

  // Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
  const macStudioEn2RdmaWarning = $derived.by(() => {
    const edges = data?.edges;
    const ids = tbIdentifiers;
    const rdmaCtl = rdmaCtlData;
    if (!edges || !ids || !rdmaCtl) return null;

    const affectedConnections: Array<{
      nodeId: string;
      nodeName: string;
      peerNodeId: string;
      peerNodeName: string;
      rdmaIface: string;
    }> = [];

    const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
      node?.system_info?.model_id === "Mac Studio";

    for (const edge of edges) {
      if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;

      const sourceNode = data?.nodes?.[edge.source];
      if (
        isMacStudio(sourceNode) &&
        edge.sourceRdmaIface === "rdma_en2" &&
        rdmaCtl[edge.source]?.enabled
      ) {
        affectedConnections.push({
          nodeId: edge.source,
          nodeName:
            sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
          peerNodeId: edge.target,
          peerNodeName:
            data?.nodes?.[edge.target]?.friendly_name ||
            edge.target.slice(0, 8) + "...",
          rdmaIface: "en2",
        });
      }

      const sinkNode = data?.nodes?.[edge.target];
      if (
        isMacStudio(sinkNode) &&
        edge.sinkRdmaIface === "rdma_en2" &&
        rdmaCtl[edge.target]?.enabled
      ) {
        affectedConnections.push({
          nodeId: edge.target,
          nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
          peerNodeId: edge.source,
          peerNodeName:
            sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
          rdmaIface: "en2",
        });
      }
    }

    // Deduplicate by nodeId
    const seen = new Set<string>();
    const unique = affectedConnections.filter((c) => {
      if (seen.has(c.nodeId)) return false;
      seen.add(c.nodeId);
      return true;
    });

    return unique.length > 0 ? unique : null;
  });
  let macStudioEn2Dismissed = $state(false);

  // Helper to get friendly node name from node ID
  function getNodeName(nodeId: string): string {
    const node = data?.nodes?.[nodeId];
@@ -790,10 +858,8 @@
    if (!progress || typeof progress !== "object") return null;

    const prog = progress as Record<string, unknown>;
-   const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
-   const downloadedBytes = getBytes(
-     prog.downloaded_bytes ?? prog.downloadedBytes,
-   );
+   const totalBytes = getBytes(prog.total);
+   const downloadedBytes = getBytes(prog.downloaded);
    const speed = (prog.speed as number) ?? 0;
    const completedFiles =
      (prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -806,8 +872,8 @@
    for (const [fileName, fileData] of Object.entries(filesObj)) {
      if (!fileData || typeof fileData !== "object") continue;
      const fd = fileData as Record<string, unknown>;
-     const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
-     const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
+     const fTotal = getBytes(fd.total);
+     const fDownloaded = getBytes(fd.downloaded);
      files.push({
        name: fileName,
        totalBytes: fTotal,
@@ -1196,7 +1262,6 @@
    if (typeof value === "number") return value;
    if (value && typeof value === "object") {
      const v = value as Record<string, unknown>;
-     if (typeof v.in_bytes === "number") return v.in_bytes;
      if (typeof v.inBytes === "number") return v.inBytes;
    }
    return 0;
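Both components normalize byte counts through this tolerant `getBytes` shape: a bare number, or an `{ in_bytes }` / `{ inBytes }` wrapper. An equivalent hedged Python helper, handy when scripting against `/state` (the helper name is ours, not part of the API):

```python
from typing import Any

def get_bytes(value: Any) -> int:
    """Unwrap a byte count that may arrive as a number or as {in_bytes}/{inBytes}."""
    if isinstance(value, (int, float)):
        return int(value)
    if isinstance(value, dict):
        for key in ("in_bytes", "inBytes"):
            if isinstance(value.get(key), (int, float)):
                return int(value[key])
    return 0

assert get_bytes(1024) == 1024
assert get_bytes({"inBytes": 2048}) == 2048
assert get_bytes(None) == 0
```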
@@ -1758,7 +1823,7 @@
</script>

{#snippet clusterWarnings()}
- {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
+ {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
    <div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
      {#if tbBridgeCycles.length > 0}
        {@const cycle = tbBridgeCycles[0]}
@@ -1923,12 +1988,260 @@
          </button>
        </div>
      {/if}

      {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
        <div class="group relative" role="alert">
          <div
            class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
          >
            <svg
              class="w-5 h-5 text-red-400 flex-shrink-0"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
              stroke-width="2"
            >
              <path
                stroke-linecap="round"
                stroke-linejoin="round"
                d={warningIconPath}
              />
            </svg>
            <span class="text-sm font-mono text-red-200">
              RDMA INCOMPATIBLE PORT
            </span>
            <button
              type="button"
              onclick={() => (macStudioEn2Dismissed = true)}
              class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
              title="Dismiss"
            >
              <svg
                class="w-4 h-4"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M6 18L18 6M6 6l12 12"
                />
              </svg>
            </button>
          </div>

          <!-- Expanded tooltip on hover -->
          <div
            class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
          >
            <p class="text-xs text-white/80 mb-3">
              The Thunderbolt 5 port next to the Ethernet port on Mac Studio
              does
              <span class="text-red-400 font-semibold">not support RDMA</span>.
              Move the cable to one of the other three TB5 ports.
            </p>

            <div class="text-xs text-white/60 mb-3">
              <span class="text-red-300">Affected:</span>
              {#each macStudioEn2RdmaWarning as conn}
                <div class="ml-2 mt-0.5">
                  <span class="text-white/80">{conn.nodeName}</span>
                  <span class="text-white/30">→</span>
                  <span class="text-white/60">{conn.peerNodeName}</span>
                  <span class="text-white/30 ml-1">(en2)</span>
                </div>
              {/each}
            </div>

            <!-- Mac Studio back panel illustration -->
            <div class="bg-black/40 rounded p-3 mb-3">
              <p
                class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
              >
                Mac Studio — Rear Panel
              </p>
              <svg
                viewBox="0 0 320 72"
                class="w-full"
                xmlns="http://www.w3.org/2000/svg"
              >
                <rect
                  x="1"
                  y="1"
                  width="318"
                  height="70"
                  rx="6"
                  ry="6"
                  fill="none"
                  stroke="rgba(255,255,255,0.12)"
                  stroke-width="1"
                />
                <!-- TB5 port 1 -->
                <rect
                  x="24"
                  y="22"
                  width="28"
                  height="14"
                  rx="4"
                  fill="none"
                  stroke="rgba(255,255,255,0.3)"
                  stroke-width="1"
                />
                <text
                  x="38"
                  y="52"
                  text-anchor="middle"
                  fill="rgba(255,255,255,0.25)"
                  style="font-size:7px;font-family:ui-monospace,monospace;"
                  >TB5</text
                >
                <!-- TB5 port 2 -->
                <rect
                  x="62"
                  y="22"
                  width="28"
                  height="14"
                  rx="4"
                  fill="none"
                  stroke="rgba(255,255,255,0.3)"
                  stroke-width="1"
                />
                <text
                  x="76"
                  y="52"
                  text-anchor="middle"
                  fill="rgba(255,255,255,0.25)"
                  style="font-size:7px;font-family:ui-monospace,monospace;"
                  >TB5</text
                >
                <!-- TB5 port 3 -->
                <rect
                  x="100"
                  y="22"
                  width="28"
                  height="14"
                  rx="4"
                  fill="none"
                  stroke="rgba(255,255,255,0.3)"
                  stroke-width="1"
                />
                <text
                  x="114"
                  y="52"
                  text-anchor="middle"
                  fill="rgba(255,255,255,0.25)"
                  style="font-size:7px;font-family:ui-monospace,monospace;"
                  >TB5</text
                >
                <!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
                <rect
                  x="138"
                  y="22"
                  width="28"
                  height="14"
                  rx="4"
                  fill="rgba(239,68,68,0.1)"
                  stroke="rgba(239,68,68,0.7)"
                  stroke-width="1.5"
                />
                <line
                  x1="142"
                  y1="25"
                  x2="162"
                  y2="33"
                  stroke="rgba(239,68,68,0.8)"
                  stroke-width="1.5"
                  stroke-linecap="round"
                />
                <line
                  x1="162"
                  y1="25"
                  x2="142"
                  y2="33"
                  stroke="rgba(239,68,68,0.8)"
                  stroke-width="1.5"
                  stroke-linecap="round"
                />
                <text
                  x="152"
                  y="52"
                  text-anchor="middle"
                  fill="rgba(239,68,68,0.6)"
                  style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
                  >en2</text
                >
                <!-- Ethernet port -->
                <rect
                  x="196"
                  y="19"
                  width="24"
                  height="20"
                  rx="2"
                  fill="none"
                  stroke="rgba(255,255,255,0.2)"
                  stroke-width="1"
                />
                <rect
                  x="200"
                  y="23"
                  width="16"
                  height="12"
                  rx="1"
                  fill="none"
                  stroke="rgba(255,255,255,0.12)"
                  stroke-width="0.75"
                />
                <text
                  x="208"
                  y="52"
                  text-anchor="middle"
                  fill="rgba(255,255,255,0.25)"
                  style="font-size:7px;font-family:ui-monospace,monospace;"
                  >ETH</text
                >
                <!-- Green checkmarks on working ports -->
                <circle
                  cx="38"
                  cy="62"
                  r="3"
                  fill="none"
                  stroke="rgba(74,222,128,0.5)"
                  stroke-width="0.75"
                />
                <circle
                  cx="76"
                  cy="62"
                  r="3"
                  fill="none"
                  stroke="rgba(74,222,128,0.5)"
                  stroke-width="0.75"
                />
                <circle
                  cx="114"
                  cy="62"
                  r="3"
                  fill="none"
                  stroke="rgba(74,222,128,0.5)"
                  stroke-width="0.75"
                />
              </svg>
            </div>

            <p class="text-xs text-white/50">
              <span class="text-green-400">Fix:</span> Move the Thunderbolt cable
              to any of the three leftmost ports (all support RDMA).
            </p>
          </div>
        </div>
      {/if}
    </div>
  {/if}
{/snippet}
{#snippet clusterWarningsCompact()}
- {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
+ {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
    <div class="absolute top-2 left-2 flex flex-col gap-1">
      {#if tbBridgeCycles.length > 0}
        <div
@@ -1996,6 +2309,27 @@
        >
      </div>
    {/if}
    {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
      <div
        class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
        title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
      >
        <svg
          class="w-3.5 h-3.5 text-red-400"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d={warningIconPath}
          />
        </svg>
        <span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
      </div>
    {/if}
    </div>
  {/if}
{/snippet}
@@ -74,7 +74,6 @@
    if (typeof value === "number") return value;
    if (value && typeof value === "object") {
      const v = value as Record<string, unknown>;
-     if (typeof v.in_bytes === "number") return v.in_bytes;
      if (typeof v.inBytes === "number") return v.inBytes;
    }
    return 0;
@@ -231,23 +230,14 @@
        undefined;
      let cell: CellStatus;
      if (tag === "DownloadCompleted") {
-       const totalBytes = getBytes(
-         payload.total_bytes ?? payload.totalBytes,
-       );
+       const totalBytes = getBytes(payload.total);
        cell = { kind: "completed", totalBytes, modelDirectory };
      } else if (tag === "DownloadOngoing") {
        const rawProgress =
          payload.download_progress ?? payload.downloadProgress ?? {};
        const prog = rawProgress as Record<string, unknown>;
-       const totalBytes = getBytes(
-         prog.total_bytes ??
-           prog.totalBytes ??
-           payload.total_bytes ??
-           payload.totalBytes,
-       );
-       const downloadedBytes = getBytes(
-         prog.downloaded_bytes ?? prog.downloadedBytes,
-       );
+       const totalBytes = getBytes(prog.total ?? payload.total);
+       const downloadedBytes = getBytes(prog.downloaded);
        const speed = (prog.speed as number) ?? 0;
        const etaMs =
          (prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;
@@ -74,7 +74,6 @@
  perSystem =
    { config, self', inputs', pkgs, lib, system, ... }:
    let
      fenixToolchain = inputs'.fenix.packages.complete;
      # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
      pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
    in
nix/mlx.nix: 10
@@ -41,16 +41,16 @@ let

  mlx = stdenv.mkDerivation rec {
    pname = "mlx";
-   version = let v = "0.30.7.dev20260218+14841977"; in
+   version = let v = "0.30.7.dev20260220+bdfe78f6"; in
      assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
      v;
    pyproject = true;

    src = fetchFromGitHub {
-     owner = "rltakashige";
-     repo = "mlx-jaccl-fix-small-recv";
-     rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
-     hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
+     owner = "JakeHillion";
+     repo = "mlx";
+     rev = "bdfe78f6e1fccb7cb3dfd049eb38f7c611e5f323";
+     hash = "sha256-MIRrJUOlC5u1SX6Tm6yuKYLo3LTE7vsMUz++rw8HWLI=";
    };

    patches = [
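The `assert` keeps the nix pin and uv.lock from drifting apart. A hedged sketch of the same consistency check as a standalone script (uv.lock is TOML with `[[package]]` tables; the script itself is illustrative and the pinned string must be updated together with nix/mlx.nix):

```python
# Illustrative consistency check, not part of this diff.
import tomllib  # Python 3.11+

NIX_MLX_VERSION = "0.30.7.dev20260220+bdfe78f6"  # mirror of nix/mlx.nix

with open("uv.lock", "rb") as f:
    lock = tomllib.load(f)

uv_version = next(
    (p["version"] for p in lock.get("package", []) if p.get("name") == "mlx"),
    None,
)
if uv_version != NIX_MLX_VERSION:
    raise SystemExit(
        f"MLX version mismatch: nix/mlx.nix has {NIX_MLX_VERSION} "
        f"but uv.lock has {uv_version}. Update both the version and hash."
    )
```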
@@ -64,7 +64,7 @@ members = [

[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
- mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
+ mlx = { git = "https://github.com/JakeHillion/mlx.git", branch = "test-mlx-lazy-import", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
@@ -1,2 +0,0 @@
- # we can manually exclude false-positive lint errors for dual packages (if in dependencies)
- #allowed-duplicate-crates = ["hashbrown"]
@@ -27,7 +27,7 @@ networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
  # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
- "nightly", # enables better-supported GIL integration
+ # "nightly", # enables better-supported GIL integration
  "experimental-async", # async support in #[pyfunction] & #[pymethods]
  #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
  #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
- pin-project = { workspace = true }

# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
futures-lite = { workspace = true }

# utility dependencies
util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"

# Networking
libp2p = { workspace = true, features = ["full"] }
+ pin-project = "1.1.10"
@@ -19,7 +19,7 @@ class ConnectionUpdate:
        Whether this is a connection or disconnection event
        """
    @property
-   def peer_id(self) -> PeerId:
+   def peer_id(self) -> builtins.str:
        r"""
        Identity of the peer that we have connected to or disconnected from.
        """
@@ -40,92 +40,22 @@ class Keypair:
    Identity keypair of a node.
    """
    @staticmethod
-   def generate_ed25519() -> Keypair:
+   def generate() -> Keypair:
        r"""
        Generate a new Ed25519 keypair.
        """
    @staticmethod
-   def generate_ecdsa() -> Keypair:
+   def from_bytes(bytes: bytes) -> Keypair:
        r"""
-       Generate a new ECDSA keypair.
-       """
-   @staticmethod
-   def generate_secp256k1() -> Keypair:
-       r"""
-       Generate a new Secp256k1 keypair.
-       """
-   @staticmethod
-   def from_protobuf_encoding(bytes: bytes) -> Keypair:
-       r"""
-       Decode a private key from a protobuf structure and parse it as a `Keypair`.
-       """
-   @staticmethod
-   def rsa_from_pkcs8(bytes: bytes) -> Keypair:
-       r"""
-       Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
-       format (i.e. unencrypted) as defined in [RFC5208].
-
-       [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
-       """
-   @staticmethod
-   def secp256k1_from_der(bytes: bytes) -> Keypair:
-       r"""
-       Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
-       structure as defined in [RFC5915].
-
-       [RFC5915]: https://tools.ietf.org/html/rfc5915
-       """
-   @staticmethod
-   def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
-   def to_protobuf_encoding(self) -> bytes:
-       r"""
-       Encode a private key as protobuf structure.
-       """
-   def to_peer_id(self) -> PeerId:
-       r"""
-       Convert the `Keypair` into the corresponding `PeerId`.
-       """
-
-@typing.final
-class Multiaddr:
-   r"""
-   Representation of a Multiaddr.
-   """
-   @staticmethod
-   def empty() -> Multiaddr:
-       r"""
-       Create a new, empty multiaddress.
-       """
-   @staticmethod
-   def with_capacity(n: builtins.int) -> Multiaddr:
-       r"""
-       Create a new, empty multiaddress with the given capacity.
-       """
-   @staticmethod
-   def from_bytes(bytes: bytes) -> Multiaddr:
-       r"""
-       Parse a `Multiaddr` value from its byte slice representation.
-       """
-   @staticmethod
-   def from_string(string: builtins.str) -> Multiaddr:
-       r"""
-       Parse a `Multiaddr` value from its string representation.
-       """
-   def len(self) -> builtins.int:
-       r"""
-       Return the length in bytes of this multiaddress.
-       """
-   def is_empty(self) -> builtins.bool:
-       r"""
-       Returns true if the length of this multiaddress is 0.
+       Construct an Ed25519 keypair from secret key bytes
        """
    def to_bytes(self) -> bytes:
        r"""
-       Return a copy of this [`Multiaddr`]'s byte representation.
+       Get the secret key bytes underlying the keypair
        """
-   def to_string(self) -> builtins.str:
+   def to_node_id(self) -> builtins.str:
        r"""
-       Convert a Multiaddr to a string.
+       Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
        """
@typing.final
|
||||
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
|
||||
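The net effect of these stub changes: `PeerId` and `Multiaddr` disappear from the Python surface and `Keypair` shrinks to an Ed25519-only API. A hedged usage sketch (the `exo_pyo3_bindings` module path is assumed from the crate name):

```python
from exo_pyo3_bindings import Keypair

kp = Keypair.generate()            # always Ed25519 now
secret = kp.to_bytes()             # raw secret key bytes, e.g. for persistence
restored = Keypair.from_bytes(secret)

# The base-58 PeerId string doubles as the NodeId; ConnectionUpdate.peer_id
# is now the same plain str.
assert kp.to_node_id() == restored.to_node_id()
```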
@@ -2,7 +2,6 @@
//!

use pin_project::pin_project;
- use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
    future::Future,
@@ -26,8 +25,8 @@ where

impl<F> Future for AllowThreads<F>
where
-   F: Future + Ungil,
-   F::Output: Ungil,
+   F: Future + Send,
+   F::Output: Send,
{
    type Output = F::Output;
rust/exo_pyo3_bindings/src/ident.rs (new file): 47
@@ -0,0 +1,47 @@
use crate::ext::ResultExt as _;
use libp2p::identity::Keypair;
use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Construct an Ed25519 keypair from secret key bytes
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Get the secret key bytes underlying the keypair
    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self
            .0
            .clone()
            .try_into_ed25519()
            .pyerr()?
            .secret()
            .as_ref()
            .to_vec();
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
    fn to_node_id(&self) -> String {
        self.0.public().to_peer_id().to_base58()
    }
}
@@ -4,26 +4,14 @@
//!
//!

// enable Rust-unstable features for convenience
#![feature(trait_alias)]
- #![feature(tuple_trait)]
- #![feature(unboxed_closures)]
- // #![feature(stmt_expr_attributes)]
- // #![feature(assert_matches)]
- // #![feature(async_fn_in_dyn_trait)]
- // #![feature(async_for_loop)]
- // #![feature(auto_traits)]
- // #![feature(negative_impls)]

extern crate core;
mod allow_threading;
- pub(crate) mod networking;
- pub(crate) mod pylibp2p;
+ mod ident;
+ mod networking;

+ use crate::ident::PyKeypair;
use crate::networking::networking_submodule;
- use crate::pylibp2p::ident::ident_submodule;
- use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;

@@ -32,14 +20,6 @@ pub(crate) mod r#const {
    pub const MPSC_CHANNEL_SIZE: usize = 1024;
}

- /// Namespace for all the type/trait aliases used by this crate.
- pub(crate) mod alias {
-     use std::marker::Tuple;
-
-     pub trait SendFn<Args: Tuple + Send + 'static, Output> =
-         Fn<Args, Output = Output> + Send + 'static;
- }

/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
    use crate::allow_threading::AllowThreads;

@@ -179,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
    // work with maturin, where the types generate correctly, in the right folder, without
    // too many importing issues...
-   ident_submodule(m)?;
-   multiaddr_submodule(m)?;
+   m.add_class::<PyKeypair>()?;
    networking_submodule(m)?;

    // top-level constructs
@@ -8,8 +8,8 @@
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
+ use crate::ident::PyKeypair;
use crate::pyclass;
- use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {

    /// Identity of the peer that we have connected to or disconnected from.
    #[pyo3(get)]
-   peer_id: PyPeerId,
+   peer_id: String,

    /// Remote connection's IPv4 address.
    #[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
    // send connection event to channel (or exit if connection closed)
    if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
        update_type: PyConnectionUpdateType::Connected,
-       peer_id: PyPeerId(peer_id),
+       peer_id: peer_id.to_base58(),
        remote_ipv4,
        remote_tcp_port,
    }).await {
@@ -272,7 +272,7 @@ async fn networking_task(
    // send disconnection event to channel (or exit if connection closed)
    if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
        update_type: PyConnectionUpdateType::Disconnected,
-       peer_id: PyPeerId(peer_id),
+       peer_id: peer_id.to_base58(),
        remote_ipv4,
        remote_tcp_port,
    }).await {
@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Generate a new ECDSA keypair.
    #[staticmethod]
    fn generate_ecdsa() -> Self {
        Self(Keypair::generate_ecdsa())
    }

    /// Generate a new Secp256k1 keypair.
    #[staticmethod]
    fn generate_secp256k1() -> Self {
        Self(Keypair::generate_secp256k1())
    }

    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
    #[staticmethod]
    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
    }

    /// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
    /// format (i.e. unencrypted) as defined in [RFC5208].
    ///
    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
    #[staticmethod]
    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
    /// structure as defined in [RFC5915].
    ///
    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
    #[staticmethod]
    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
    }

    #[staticmethod]
    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Encode a private key as protobuf structure.
    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self.0.to_protobuf_encoding().pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId`.
    fn to_peer_id(&self) -> PyPeerId {
        PyPeerId(self.0.public().to_peer_id())
    }

    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
    // #[gen_stub(skip)]
    // #[new]
    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
    //     Self::from_protobuf_encoding(bytes)
    // }
    //
    // #[gen_stub(skip)]
    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
    //     *self = Self::from_protobuf_encoding(state)?;
    //     Ok(())
    // }
    //
    // #[gen_stub(skip)]
    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
    //     self.to_protobuf_encoding(py)
    // }
    //
    // #[gen_stub(skip)]
    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
    //     Ok((self.to_protobuf_encoding(py)?,))
    // }
}

/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
    /// Generates a random peer ID from a cryptographically secure PRNG.
    ///
    /// This is useful for randomly walking on a DHT, or for testing purposes.
    #[staticmethod]
    fn random() -> Self {
        Self(PeerId::random())
    }

    /// Parses a `PeerId` from bytes.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
    }

    /// Returns a raw bytes representation of this `PeerId`.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_bytes();
        PyBytes::new(py, &bytes)
    }

    /// Returns a base-58 encoded string of this `PeerId`.
    fn to_base58(&self) -> String {
        self.0.to_base58()
    }

    fn __repr__(&self) -> String {
        format!("PeerId({})", self.to_base58())
    }

    fn __str__(&self) -> String {
        self.to_base58()
    }
}

pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    m.add_class::<PyPeerId>()?;

    Ok(())
}
@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!

pub mod ident;
pub mod multiaddr;
@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;

/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}

pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMultiaddr>()?;

    Ok(())
}
@@ -22,7 +22,7 @@ delegate = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
- futures = { workspace = true }
+ futures-lite = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
@@ -1,4 +1,4 @@
- use futures::stream::StreamExt as _;
+ use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
                println!("Publish error: {e:?}");
            }
        }
-       event = swarm.select_next_some() => match event {
+       event = swarm.next() => match event {
            // on gossipsub incoming
-           SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
+           Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                propagation_source: peer_id,
                message_id: id,
                message,
-           })) => println!(
+           }))) => println!(
                "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                String::from_utf8_lossy(&message.data),
            ),

            // on discovery
-           SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
+           Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
                discovery::Event::ConnectionEstablished {
                    peer_id, connection_id, remote_ip, remote_tcp_port
                } => {
@@ -64,7 +64,7 @@ async fn main() {
            }

            // ignore outgoing errors: those are normal
-           e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
+           e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }

            // otherwise log any other event
            e => { log::info!("Other event {e:?}"); }
@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::stream::StreamExt;
use libp2p::{
    gossipsub, mdns, noise,
    swarm::{NetworkBehaviour, SwarmEvent},
    tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    gossipsub: gossipsub::Behaviour,
    mdns: mdns::tokio::Behaviour,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_lite::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
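
Likewise, futures_lite::FutureExt exposes poll on &mut self for Unpin futures, so the old poll_unpin call maps over directly. A small sketch (the helper name is ours, not from this diff):

use futures_lite::FutureExt; // provides `poll(&mut self, cx)` where Self: Unpin
use futures_timer::Delay;
use std::task::Context;

// Returns true once the retry timer has fired; mirrors the check above.
fn retry_due(delay: &mut Delay, cx: &mut Context<'_>) -> bool {
    delay.poll(cx).is_ready() // old form: delay.poll_unpin(cx).is_ready()
}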

@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;

@@ -1,11 +1,10 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
{ inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
rustToolchain = inputs'.fenix.packages.stable.withComponents [
"cargo"
"rustc"
"clippy"

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"
@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
total=initial_progress.total,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
total=progress.total,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total,
downloaded=repo_file_download_progress.downloaded,
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes,
downloaded_bytes=repo_download_progress.downloaded_bytes,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
total=repo_download_progress.total,
downloaded=repo_download_progress.downloaded,
downloaded_this_session=repo_download_progress.downloaded_this_session,
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
@@ -578,19 +578,20 @@ def calculate_repo_progress(
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
) -> RepoDownloadProgress:
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
all_downloaded_bytes = sum(
(p.downloaded.in_bytes for p in file_progress.values()), 0
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
all_downloaded = sum(
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
)
all_downloaded_bytes_this_session = sum(
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
all_downloaded_this_session = sum(
(p.downloaded_this_session for p in file_progress.values()),
Memory.from_bytes(0),
)
elapsed_time = time.time() - all_start_time
all_speed = (
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
)
all_eta = (
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
if all_speed > 0
else timedelta(seconds=0)
)
@@ -609,11 +610,9 @@ def calculate_repo_progress(
[p for p in file_progress.values() if p.downloaded == p.total]
),
total_files=len(file_progress),
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(
all_downloaded_bytes_this_session
),
total_bytes=Memory.from_bytes(all_total_bytes),
downloaded=all_downloaded,
downloaded_this_session=all_downloaded_this_session,
total=all_total,
overall_speed=all_speed,
overall_eta=all_eta,
status=status,

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
downloaded=Memory.from_bytes(0),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",

@@ -45,7 +45,7 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)

@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
if (
msg.content is None
and msg.reasoning_content is None
and msg.tool_calls is None
):
continue

if msg.role in ("user", "assistant", "developer"):
@@ -111,6 +115,11 @@ def chunk_to_response(
]
)

if chunk.is_thinking:
delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
else:
delta = ChatCompletionMessage(role="assistant", content=chunk.text)

return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -118,7 +127,7 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=delta,
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
@@ -208,6 +217,7 @@ async def collect_chat_response(
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
@@ -228,7 +238,10 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
text_parts.append(chunk.text)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
@@ -258,6 +271,7 @@ async def collect_chat_response(
raise ValueError(error_message)

combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None

yield ChatCompletionResponse(
@@ -270,6 +284,7 @@ async def collect_chat_response(
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
reasoning_content=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)

@@ -1,6 +1,7 @@
"""Claude Messages API adapter for converting requests/responses."""

import json
import re
from collections.abc import AsyncGenerator
from typing import Any

@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeThinkingBlock,
ClaudeThinkingDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
return "".join(sub_block.text for sub_block in block.content)


# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def _strip_volatile_headers(text: str) -> str:
"""Remove Anthropic billing/telemetry headers from system prompt text.

Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
change every request and break KV prefix caching (the prefix diverges at ~20
tokens instead of matching thousands of conversation tokens).
"""
return _VOLATILE_HEADER_RE.sub("", text)
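
A quick check of what the regex removes (header values invented for illustration):

# Hypothetical system prompt as Claude Code would send it.
prompt = (
    "x-anthropic-billing-header: cc_version=1.0.0; cc_entrypoint=cli; cch=deadbeef;\n"
    "You are Claude Code, ..."
)
assert _strip_volatile_headers(prompt) == "You are Claude Code, ..."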


def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)

instructions = _strip_volatile_headers(instructions)
chat_template_messages.append({"role": "system", "content": instructions})

# Convert messages to input
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(

# Process structured content blocks
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []

for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeThinkingBlock):
thinking_parts.append(block.thinking)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
tool_results.append(block)

content = "".join(text_parts)
reasoning_content = "".join(thinking_parts) if thinking_parts else None

# Build InputMessage from text content
if msg.role in ("user", "assistant"):
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(

# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
chat_msg: dict[str, Any] = {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
chat_msg = {"role": msg.role, "content": content}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)

# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
for tool in request.tools
]

enable_thinking: bool | None = None
if request.thinking is not None:
enable_thinking = request.thinking.type in ("enabled", "adaptive")

return TextGenerationTaskParams(
model=request.model,
input=input_messages
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
enable_thinking=enable_thinking,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
@@ -173,6 +211,7 @@ async def collect_claude_response(
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
@@ -200,7 +239,10 @@ async def collect_claude_response(
stop_reason = "tool_use"
continue

text_parts.append(chunk.text)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)

if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -209,9 +251,12 @@ async def collect_claude_response(
raise ValueError(error_message)

combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts)

# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_thinking:
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
@@ -256,16 +301,16 @@ async def generate_claude_stream(
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"

# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"

output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
next_block_index = 0

# Track whether we've started thinking/text blocks
thinking_block_started = False
thinking_block_index = -1
text_block_started = False
text_block_index = -1

async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
@@ -310,12 +355,45 @@ async def generate_claude_stream(

output_tokens += 1 # Count each chunk as one token

# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.is_thinking:
# Start thinking block on first thinking token
if not thinking_block_started:
thinking_block_started = True
thinking_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=thinking_block_index,
content_block=ClaudeThinkingBlock(thinking=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"

delta_event = ClaudeContentBlockDeltaEvent(
index=thinking_block_index,
delta=ClaudeThinkingDelta(thinking=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
else:
# Close thinking block when transitioning to text
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"

# Start text block on first text token
if not text_block_started:
text_block_started = True
text_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=text_block_index,
content_block=ClaudeTextBlock(text=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"

delta_event = ClaudeContentBlockDeltaEvent(
index=text_block_index,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"

if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -324,9 +402,22 @@ async def generate_claude_stream(
if last_usage is not None:
output_tokens = last_usage.completion_tokens

# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# Close any open blocks
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"

if text_block_started:
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"

if not thinking_block_started and not text_block_started:
empty_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
empty_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"

# message_delta
message_delta = ClaudeMessageDeltaEvent(

@@ -29,8 +29,15 @@ from exo.shared.types.openai_responses import (
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryText,
ResponseReasoningSummaryTextDeltaEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -38,6 +45,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"


def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -135,7 +147,9 @@ async def collect_responses_response(
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
accumulated_text = ""
thinking_parts: list[str] = []
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
error_message: str | None = None
@@ -162,6 +176,10 @@ async def collect_responses_response(
)
continue

if chunk.is_thinking:
thinking_parts.append(chunk.text)
continue

accumulated_text += chunk.text

if error_message is not None:
@@ -176,13 +194,21 @@ async def collect_responses_response(
total_tokens=last_usage.total_tokens,
)

output: list[ResponseItem] = [
output: list[ResponseItem] = []
if thinking_parts:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
)
)
output.append(
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
)
output.extend(function_call_items)

yield ResponsesResponse(
@@ -206,6 +232,7 @@ async def generate_responses_stream(
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
seq = count(1)

# response.created
@@ -219,40 +246,25 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)

# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"

# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"

# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)

accumulated_text = ""
accumulated_thinking = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
next_output_index = 0

# Track dynamic block creation
reasoning_started = False
reasoning_output_index = -1
message_started = False
message_output_index = -1

async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
@@ -281,7 +293,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)

# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +302,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)

# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +312,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)

# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -315,44 +327,205 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)

function_call_items.append(fc_done_item)
next_output_index += 1
continue

if chunk.is_thinking:
# Start reasoning block on first thinking token
if not reasoning_started:
reasoning_started = True
reasoning_output_index = next_output_index
next_output_index += 1

# response.output_item.added for reasoning
reasoning_item = ResponseReasoningItem(
id=reasoning_id,
summary=[],
status="in_progress",
)
rs_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=reasoning_item,
)
yield _format_sse(rs_added)

# response.reasoning_summary_part.added
part_added = ResponseReasoningSummaryPartAddedEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=""),
)
yield _format_sse(part_added)

accumulated_thinking += chunk.text

# response.reasoning_summary_text.delta
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
delta=chunk.text,
)
yield _format_sse(rs_delta)
continue

# Close reasoning block when transitioning to text
if reasoning_started and not message_started:
# response.reasoning_summary_text.done
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)

# response.reasoning_summary_part.done
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)

# response.output_item.done for reasoning
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)

# Start message block on first text token
if not message_started:
message_started = True
message_output_index = next_output_index
next_output_index += 1

initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)

initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added)

accumulated_text += chunk.text

# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)

# Close reasoning block if it was never followed by text
if reasoning_started and not message_started:
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)

rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)

rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)

# If no message block was started, create one now (empty text)
if not message_started:
message_output_index = next_output_index
next_output_index += 1

initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)

initial_part = ResponseOutputText(text="")
part_added_evt = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added_evt)

# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)

# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)

# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -361,9 +534,11 @@ async def generate_responses_stream(
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
sequence_number=next(seq),
output_index=message_output_index,
item=final_message_item,
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)

# Create usage from usage data if available
usage = None
@@ -375,7 +550,15 @@ async def generate_responses_stream(
)

# response.completed
output: list[ResponseItem] = [final_message_item]
output: list[ResponseItem] = []
if reasoning_started:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
)
)
output.append(final_message_item)
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,
@@ -388,4 +571,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

@@ -138,7 +138,6 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -1323,7 +1322,7 @@ class API:
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
storage_size_megabytes=card.storage_size.in_mb,
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
@@ -1455,22 +1454,6 @@ class API:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)

elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressChunk(
model=event.model,
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)

if isinstance(event, TracesMerged):
self._save_merged_trace(event)


@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available / total_memory for node_id in node_ids
],
)

total_storage_bytes = model_card.storage_size.in_bytes
total_storage = model_card.storage_size
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
)

return layer_allocations
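
The validation arithmetic now stays in Memory space end to end; a worked example with invented numbers (a 10 GiB model split 8 layers out of 32 needs 2.5 GiB on that node):

total_storage = Memory.from_bytes(10 * 1024**3)  # hypothetical 10 GiB model
node_layers, total_layers = 8, 32
required = (total_storage * node_layers) // total_layers  # Memory * int, then // int -> Memory
assert required.in_gb == 2.5
assert required > Memory.from_bytes(0)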
@@ -342,6 +341,7 @@ def _find_ip_prioritised(
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
ring: bool,
) -> str | None:
"""Find an IP address between nodes with prioritization.

@@ -354,13 +354,27 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}

# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
"wifi": 3,
"unknown": 4,
}

# RDMA prefers ethernet coordinator
else:
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
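
Under the ring table a Thunderbolt address therefore beats Ethernet, while the RDMA table reverses that; for example (addresses invented):

ip_to_type = {"10.0.0.5": "ethernet", "169.254.7.9": "thunderbolt"}
ring_priority = {"thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, "wifi": 3, "unknown": 4}
ips = list(ip_to_type)
best = min(ips, key=lambda ip: ring_priority.get(ip_to_type.get(ip, "unknown"), 2))
assert best == "169.254.7.9"  # thunderbolt wins for ring transports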
@@ -400,7 +414,7 @@ def get_mlx_ring_hosts_by_node(
continue

connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
node_id, other_node_id, cycle_digraph, node_network, ring=True
)
if connection_ip is None:
raise ValueError(
@@ -431,7 +445,9 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"

ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
if ip is not None:
return ip


@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:

parsed = _parse_sse_events(events)

# Two tool block starts (at indices 1 and 2)
# Two tool block starts (at indices 0 and 1 — no text block when only tools)
tool_starts = [
e
for e in parsed
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2
assert tool_starts[0]["index"] == 0
assert tool_starts[1]["index"] == 1

# Two tool block stops (at indices 1 and 2), plus text block stop at 0
# Two tool block stops (at indices 0 and 1)
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)

ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)

sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
) # make it exactly fit across all nodes
topology = Topology()

@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_card.storage_size = Memory.from_bytes(1500)

node_a = NodeId()
node_b = NodeId()

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,

@@ -221,7 +221,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
return Keypair.generate()

def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
protobuf_encoded = f.read()

try: # if decoded successfully, save & return
return Keypair.from_protobuf_encoding(protobuf_encoded)
return Keypair.from_bytes(protobuf_encoded)
except ValueError as e: # on runtime error, assume corrupt file
logger.warning(f"Encountered error when trying to get keypair: {e}")

# if no valid credentials, create new ones and persist
with open(path, "w+b") as f:
keypair = Keypair.generate_ed25519()
f.write(keypair.to_protobuf_encoding())
f.write(keypair.to_bytes())
return keypair
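
The persisted file is just the keypair's byte encoding, so recovery is a pure roundtrip. A sketch using only the method names that appear above (behaviour assumed, not verified against the bindings):

keypair = Keypair.generate()
restored = Keypair.from_bytes(keypair.to_bytes())
assert restored.to_node_id() == keypair.to_node_id()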

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)

new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
total=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
total=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})


@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
sem.release()
# wait to be told to begin simultaneous read
ev.wait()
queue.put(get_node_id_keypair().to_protobuf_encoding())
queue.put(get_node_id_keypair().to_bytes())


def _get_keypair_concurrent(num_procs: int) -> bytes:

@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
content: (
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
reasoning_content: str | None = None
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None

@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
stats: GenerationStats | None = None
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
is_thinking: bool = False


class ErrorChunk(BaseChunk):

@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
source: ClaudeImageSource


class ClaudeThinkingBlock(BaseModel, frozen=True):
"""Thinking content block in Claude Messages API."""

type: Literal["thinking"] = "thinking"
thinking: str
signature: str | None = None


class ClaudeToolUseBlock(BaseModel, frozen=True):
"""Tool use content block in Claude Messages API."""

@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
cache_control: dict[str, str] | None = None


ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
ClaudeContentBlock = (
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
)

# Input content blocks can also include tool_result (sent by user after tool_use)
ClaudeInputContentBlock = (
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
ClaudeTextBlock
| ClaudeImageBlock
| ClaudeThinkingBlock
| ClaudeToolUseBlock
| ClaudeToolResultBlock
)


@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
content: str | list[ClaudeInputContentBlock]


class ClaudeThinkingConfig(BaseModel, frozen=True):
type: Literal["enabled", "disabled", "adaptive"]
budget_tokens: int | None = None


class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""

@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
top_k: int | None = None
tools: list[ClaudeToolDefinition] | None = None
metadata: dict[str, str] | None = None
thinking: ClaudeThinkingConfig | None = None


# Response types
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):

type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock | ClaudeToolUseBlock
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock


class ClaudeTextDelta(BaseModel, frozen=True):
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
text: str


class ClaudeThinkingDelta(BaseModel, frozen=True):
"""Delta for thinking content block."""

type: Literal["thinking_delta"] = "thinking_delta"
thinking: str


class ClaudeInputJsonDelta(BaseModel, frozen=True):
"""Delta for tool use input JSON content block."""

@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):

type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta | ClaudeInputJsonDelta
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta


class ClaudeContentBlockStopEvent(BaseModel, frozen=True):

@@ -5,7 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk


class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int


class TopologyEdgeCreated(BaseEvent):
conn: Connection

@@ -155,7 +148,6 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

@@ -1,10 +1,10 @@
from math import ceil
from typing import Self
from typing import Self, overload

from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import FrozenModel


class Memory(CamelCaseModel):
class Memory(FrozenModel):
in_bytes: int = 0

@classmethod
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
return cls(in_bytes=round(val * 1024))

@property
def in_mb(self) -> float:
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)
def in_mb(self) -> int:
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
return round(self.in_bytes / (1024**2))

@in_mb.setter
def in_mb(self, val: float):
def in_mb(self, val: int):
"""Set the megabytes for this memory."""
self.in_bytes = val * (1024**2)

@property
def in_float_mb(self) -> float:
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
return self.in_bytes / (1024**2)

@in_float_mb.setter
def in_float_mb(self, val: float):
"""Set the megabytes for this memory, rounded to the nearest byte."""
self.in_bytes = round(val * (1024**2))

@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
"""The approximate gigabytes this memory represents."""
return self.in_bytes / (1024**3)

def __add__(self, other: "Memory") -> "Memory":
return Memory.from_bytes(self.in_bytes + other.in_bytes)
def __add__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes + other.in_bytes)
return NotImplemented

def __lt__(self, other: Self) -> bool:
return self.in_bytes < other.in_bytes
def __radd__(self, other: object) -> "Memory":
if other == 0:
return self
return NotImplemented

def __le__(self, other: Self) -> bool:
return self.in_bytes <= other.in_bytes
def __sub__(self, other: object) -> "Memory":
if isinstance(other, Memory):
return Memory.from_bytes(self.in_bytes - other.in_bytes)
return NotImplemented

def __gt__(self, other: Self) -> bool:
return self.in_bytes > other.in_bytes
def __mul__(self, other: int | float):
return Memory.from_bytes(round(self.in_bytes * other))

def __ge__(self, other: Self) -> bool:
return self.in_bytes >= other.in_bytes
def __rmul__(self, other: int | float):
return self * other

@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented

def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented

def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented

def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented

def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented

def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented

def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented

def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"

def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"

return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"
|
||||
|
||||
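For orientation, a minimal sketch of the operator semantics introduced above: dividing two Memory values yields a dimensionless float ratio, dividing by a scalar yields a Memory, and __radd__ accepting 0 keeps built-in sum() working. The import path is an assumption from the surrounding diff, not part of the patch.

    # Sketch (not part of the patch): exercising the Memory operators above.
    from exo.shared.types.memory import Memory  # assumed module path

    a = Memory.from_bytes(3 * 1024**3)  # 3 GiB
    b = Memory.from_bytes(1024**3)      # 1 GiB

    print(a + b)        # Memory + Memory -> Memory, prints "4 GiB"
    print(a / b)        # Memory / Memory -> float ratio, prints 3.0
    print(a / 2)        # Memory / scalar -> Memory (1.5 GiB worth of bytes)
    print(0.9 * b)      # __rmul__ lets scalars appear on the left
    print(sum([a, b]))  # __radd__(0) makes built-in sum() work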
@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
    status: ResponseStatus = "completed"


ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
class ResponseReasoningSummaryText(BaseModel, frozen=True):
    """Summary text part in a reasoning output item."""

    type: Literal["summary_text"] = "summary_text"
    text: str


class ResponseReasoningItem(BaseModel, frozen=True):
    """Reasoning output item in the response output array."""

    type: Literal["reasoning"] = "reasoning"
    id: str
    summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
    status: ResponseStatus = "completed"


ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem


class ResponseUsage(BaseModel, frozen=True):
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
    arguments: str


class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
    """Event sent when a reasoning summary part is added."""

    type: Literal["response.reasoning_summary_part.added"] = (
        "response.reasoning_summary_part.added"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    part: ResponseReasoningSummaryText


class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
    """Event sent for each reasoning summary text delta during streaming."""

    type: Literal["response.reasoning_summary_text.delta"] = (
        "response.reasoning_summary_text.delta"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    delta: str


class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
    """Event sent when reasoning summary text is done."""

    type: Literal["response.reasoning_summary_text.done"] = (
        "response.reasoning_summary_text.done"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    text: str


class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
    """Event sent when a reasoning summary part is done."""

    type: Literal["response.reasoning_summary_part.done"] = (
        "response.reasoning_summary_part.done"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    part: ResponseReasoningSummaryText


class ResponseCompletedEvent(BaseModel, frozen=True):
    """Event sent when response is completed."""

@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
    | ResponseOutputItemDoneEvent
    | ResponseFunctionCallArgumentsDeltaEvent
    | ResponseFunctionCallArgumentsDoneEvent
    | ResponseReasoningSummaryPartAddedEvent
    | ResponseReasoningSummaryTextDeltaEvent
    | ResponseReasoningSummaryTextDoneEvent
    | ResponseReasoningSummaryPartDoneEvent
    | ResponseCompletedEvent
)

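For orientation, a sketch of the event order a client might observe for one reasoning summary part under the models added above. The ids and indices are made up for the example; the class names are the ones defined in this diff.

    # Illustrative streaming order for a single reasoning summary part.
    # All ids/indices below are invented for the example.
    events = [
        ResponseReasoningSummaryPartAddedEvent(
            sequence_number=1, item_id="rs_1", output_index=0, summary_index=0,
            part=ResponseReasoningSummaryText(text=""),
        ),
        ResponseReasoningSummaryTextDeltaEvent(
            sequence_number=2, item_id="rs_1", output_index=0, summary_index=0,
            delta="Checking the weather tool...",
        ),
        ResponseReasoningSummaryTextDoneEvent(
            sequence_number=3, item_id="rs_1", output_index=0, summary_index=0,
            text="Checking the weather tool...",
        ),
        ResponseReasoningSummaryPartDoneEvent(
            sequence_number=4, item_id="rs_1", output_index=0, summary_index=0,
            part=ResponseReasoningSummaryText(text="Checking the weather tool..."),
        ),
    ]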
@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


class DownloadProgressData(CamelCaseModel):
    total_bytes: Memory
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total: Memory
    downloaded: Memory
    downloaded_this_session: Memory

    completed_files: int
    total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):


class DownloadCompleted(BaseDownloadProgress):
    total_bytes: Memory
    total: Memory


class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
    shard: ShardMetadata
    completed_files: int
    total_files: int
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    downloaded: Memory
    downloaded_this_session: Memory
    total: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]

@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None
    is_thinking: bool = False


class ImageGenerationResponse(BaseRunnerResponse):

@@ -192,7 +192,13 @@ class MpReceiver[T]:
        try:
            return self.receive_nowait()
        except WouldBlock:
            item = self._state.buffer.get()
            try:
                item = self._state.buffer.get()
            except (TypeError, OSError):
                # The queue pipe can get closed while we are blocked on get().
                # The underlying connection._handle becomes None, causing a
                # TypeError in read(handle, remaining).
                raise ClosedResourceError from None
            if isinstance(item, _MpEndOfStream):
                self.close()
                raise EndOfStream from None

@@ -108,7 +108,7 @@ async def check_reachable(
    await send.send((target_ip, expected_node_id))

    async with (
        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
        httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
        create_task_group() as tg,
    ):
        for node_id in topology.list_nodes():

@@ -166,7 +166,7 @@ def generate_image(
        else 0.0
    )

    peak_memory_gb = mx.get_peak_memory() / (1024**3)
    peak_memory = Memory.from_bytes(mx.get_peak_memory())

    stats = ImageGenerationStats(
        seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
        num_images=num_images,
        image_width=width,
        image_height=height,
        peak_memory_usage=Memory.from_gb(peak_memory_gb),
        peak_memory_usage=peak_memory,
    )

    buffer = io.BytesIO()

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger

# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = psutil.virtual_memory().total / (1024**3)
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:

src/exo/worker/engines/mlx/dsml_encoding.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import json
import re
from typing import Any

from mlx_lm.chat_templates import deepseek_v32

from exo.shared.types.api import ToolCallItem

BOS_TOKEN: str = deepseek_v32.bos_token
EOS_TOKEN: str = deepseek_v32.eos_token
DSML_TOKEN: str = deepseek_v32.dsml_token
THINKING_START: str = deepseek_v32.thinking_start_token
THINKING_END: str = deepseek_v32.thinking_end_token
USER_TOKEN = "<\uff5cUser\uff5c>"
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
encode_messages = deepseek_v32.encode_messages

_INVOKE_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}invoke>",
    re.DOTALL,
)

_PARAM_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}parameter>",
    re.DOTALL,
)


def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
    """Parse a DSML function_calls block from model output text.

    Args:
        text: The text containing the DSML function_calls block
            (including the start/end markers).

    Returns:
        A list of ToolCallItem, or None if parsing fails.
    """
    tool_calls: list[ToolCallItem] = []

    for invoke_match in _INVOKE_PATTERN.finditer(text):
        func_name = invoke_match.group(1)
        invoke_body = invoke_match.group(2)

        args: dict[str, Any] = {}
        for param_match in _PARAM_PATTERN.finditer(invoke_body):
            param_name = param_match.group(1)
            is_string = param_match.group(2) == "true"
            param_value = param_match.group(3)

            if is_string:
                args[param_name] = param_value
            else:
                try:
                    args[param_name] = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    args[param_name] = param_value

        tool_calls.append(
            ToolCallItem(
                name=func_name,
                arguments=json.dumps(args),
            )
        )

    return tool_calls if tool_calls else None
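As a quick sanity check, a minimal sketch of what this parser accepts (the DSML block is hand-built from the markers defined above; parameters with string="false" are decoded as JSON):

    # Minimal sketch: build a DSML block by hand and parse it back.
    block = (
        f"{TOOL_CALLS_START}\n"
        f'<{DSML_TOKEN}invoke name="get_weather">\n'
        f'<{DSML_TOKEN}parameter name="city" string="true">Tokyo</{DSML_TOKEN}parameter>\n'
        f'<{DSML_TOKEN}parameter name="units" string="false">"celsius"</{DSML_TOKEN}parameter>\n'
        f"</{DSML_TOKEN}invoke>\n"
        f"{TOOL_CALLS_END}"
    )
    calls = parse_dsml_output(block)
    assert calls is not None and calls[0].name == "get_weather"
    # string="false" values go through json.loads, so "celsius" arrives unquoted.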
@@ -232,11 +232,11 @@ def shard_and_load(

    # Estimate timeout based on model size (5x default for large queued workloads)
    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
    timeout_seconds = base_timeout + model_size_gb
    model_size = get_weights_size(shard_metadata)
    timeout_seconds = base_timeout + model_size.in_gb
    logger.info(
        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
        f"(model size: {model_size_gb:.1f}GB)"
        f"(model size: {model_size.in_gb:.1f}GB)"
    )

    match shard_metadata:

@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
    return patched if n > 0 else None


def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
    if "deepseek-v3.2" not in task_params.model.lower():
        return False
    # Use DSML encoding when tools are provided or tool results are in the conversation
    if task_params.tools:
        return True
    if task_params.chat_template_messages:
        return any(
            msg.get("role") == "tool" for msg in task_params.chat_template_messages
        )
    return False


def apply_chat_template(
    tokenizer: TokenizerWrapper,
    task_params: TextGenerationTaskParams,
@@ -469,7 +482,6 @@

    When chat_template_messages is available (from the Chat Completions API),
    uses those directly to preserve tool_calls, thinking, and other fields.
    Otherwise builds messages from the task params input/instructions.
    """
    formatted_messages: list[dict[str, Any]] = []
    if task_params.chat_template_messages is not None:
@@ -497,6 +509,19 @@
        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
        formatted_messages = formatted_messages[:-1]

    if _needs_dsml_encoding(task_params):
        from exo.worker.engines.mlx.dsml_encoding import encode_messages

        prompt = encode_messages(
            messages=formatted_messages,
            thinking_mode="thinking" if task_params.enable_thinking else "chat",
            tools=task_params.tools,
        )
        if partial_assistant_content:
            prompt += partial_assistant_content
        logger.info(prompt)
        return prompt

    extra_kwargs: dict[str, Any] = {}
    if task_params.enable_thinking is not None:
        # Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".

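Putting the two pieces together, a minimal sketch of the DSML path through apply_chat_template. The task-params object here is a hypothetical stand-in with only the fields the code above actually reads, and the model id is an invented example:

    # Hypothetical stand-in for TextGenerationTaskParams (illustration only).
    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class FakeTaskParams:
        model: str = "deepseek-v3.2-example"  # invented id containing "deepseek-v3.2"
        tools: list[dict[str, Any]] = field(default_factory=list)
        chat_template_messages: list[dict[str, Any]] | None = None
        enable_thinking: bool | None = None

    params = FakeTaskParams(
        tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
        chat_template_messages=[{"role": "user", "content": "Weather in NYC?"}],
        enable_thinking=True,
    )
    # Tools present + a DeepSeek V3.2 model id -> the DSML branch is taken,
    # with thinking_mode="thinking" because enable_thinking is truthy.
    assert _needs_dsml_encoding(params)  # duck-typed stand-in
    prompt = encode_messages(
        messages=params.chat_template_messages,
        thinking_mode="thinking" if params.enable_thinking else "chat",
        tools=params.tools,
    )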
@@ -617,18 +642,17 @@ def set_wired_limit_for_model(model_size: Memory):
    if not mx.metal.is_available():
        return

    model_bytes = model_size.in_bytes
    max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
    max_rec_size = Memory.from_bytes(
        int(mx.metal.device_info()["max_recommended_working_set_size"])
    )
    if model_size > 0.9 * max_rec_size:
        logger.warning(
            f"Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
            f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
    mx.set_wired_limit(max_rec_size)
    mx.set_wired_limit(max_rec_size.in_bytes)
    logger.info(f"Wired limit set to {max_rec_size}.")

@@ -241,6 +241,11 @@ class Worker:
                    cancelled_task_id=cancelled_task_id, runner_id=runner_id
                ):
                    await self.runners[runner_id].cancel_task(cancelled_task_id)
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id, task_status=TaskStatus.Complete
                        )
                    )
                case ImageEdits() if task.task_params.total_input_chunks > 0:
                    # Assemble image from chunks and inject into task
                    cmd_id = task.command_id

@@ -4,9 +4,10 @@ import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import TYPE_CHECKING, Literal

import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.api import ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
    ErrorChunk,
    ImageChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    PrefillProgress,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,

@@ -315,11 +321,13 @@ def main(
    ) -> None:
        if device_rank == 0:
            event_sender.send(
                PrefillProgress(
                ChunkGenerated(
                    command_id=command_id,
                    model=shard_metadata.model_card.model_id,
                    processed_tokens=processed,
                    total_tokens=total,
                    chunk=PrefillProgressChunk(
                        model=shard_metadata.model_card.model_id,
                        processed_tokens=processed,
                        total_tokens=total,
                    ),
                )
            )
        cancelled_tasks.update(cancel_receiver.collect())

@@ -346,16 +354,22 @@ def main(
                        group=group,
                    )

                    # For other thinking models (GLM, etc.), check if we need to
                    # prepend the thinking tag that was consumed by the chat template
                    if detect_thinking_prompt_suffix(prompt, tokenizer):
                    if tokenizer.has_thinking:
                        mlx_generator = parse_thinking_models(
                            mlx_generator, tokenizer
                            mlx_generator,
                            tokenizer,
                            # For other thinking models (GLM, etc.), check if we need to
                            # prepend the thinking tag that was consumed by the chat template
                            starts_in_thinking=detect_thinking_prompt_suffix(
                                prompt, tokenizer
                            ),
                        )

                    # GPT-OSS specific parsing to match other model formats.
                    # Model-specific output parsing for tool calls.
                    if isinstance(inference_model, GptOssModel):
                        mlx_generator = parse_gpt_oss(mlx_generator)
                    elif isinstance(inference_model, DeepseekV32Model):
                        mlx_generator = parse_deepseek_v32(mlx_generator)
                    elif tool_parser:
                        mlx_generator = parse_tool_calls(mlx_generator, tool_parser)

@@ -407,6 +421,7 @@ def main(
                                stats=response.stats,
                                logprob=response.logprob,
                                top_logprobs=response.top_logprobs,
                                is_thinking=response.is_thinking,
                            ),
                        )
                    )

@@ -573,6 +588,13 @@ def main(
                case Shutdown():
                    current_status = RunnerShuttingDown()
                    logger.info("runner shutting down")
                    if not TYPE_CHECKING:
                        del inference_model, image_model, tokenizer, group
                    mx.clear_cache()
                    import gc

                    gc.collect()

                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
@@ -597,12 +619,8 @@ def main(
        event_sender.send(
            RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
        )
        if isinstance(current_status, RunnerShutdown):
            del inference_model, image_model, tokenizer, group
            mx.clear_cache()
            import gc

            gc.collect()
        if isinstance(current_status, RunnerShutdown):
            break

@@ -668,44 +686,208 @@ def parse_gpt_oss(

        if ch == "analysis" and not thinking:
            thinking = True
            yield response.model_copy(update={"text": "<think>"})

        if ch != "analysis" and thinking:
            thinking = False
            yield response.model_copy(update={"text": "</think>"})

        if delta:
            yield response.model_copy(update={"text": delta})
            yield response.model_copy(update={"text": delta, "is_thinking": thinking})

        if response.finish_reason is not None:
            if thinking:
                yield response.model_copy(update={"text": "</think>"})
            yield response


def parse_deepseek_v32(
    responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
    """Parse DeepSeek V3.2 DSML tool calls from the generation stream.

    Uses accumulated-text matching (not per-token marker checks) because
    DSML markers like <|DSML|function_calls> may span multiple tokens.
    Also handles <think>...</think> blocks for thinking mode.
    """
    from exo.worker.engines.mlx.dsml_encoding import (
        THINKING_END,
        THINKING_START,
        TOOL_CALLS_END,
        TOOL_CALLS_START,
        parse_dsml_output,
    )

    accumulated = ""
    in_tool_call = False
    thinking = False
    # Tokens buffered while we detect the start of a DSML block
    pending_buffer: list[GenerationResponse] = []
    # Text accumulated during a tool call block
    tool_call_text = ""

    for response in responses:
        assert isinstance(response, GenerationResponse)

        # ── Handle thinking tags ──
        if not thinking and THINKING_START in response.text:
            thinking = True
            # Yield any text before the <think> tag
            before = response.text[: response.text.index(THINKING_START)]
            if before:
                yield response.model_copy(update={"text": before})
            continue

        if thinking and THINKING_END in response.text:
            thinking = False
            # Yield any text after the </think> tag
            after = response.text[
                response.text.index(THINKING_END) + len(THINKING_END) :
            ]
            if after:
                yield response.model_copy(update={"text": after, "is_thinking": False})
            continue

        if thinking:
            yield response.model_copy(update={"is_thinking": True})
            continue

        # ── Handle tool call accumulation ──
        if in_tool_call:
            tool_call_text += response.text
            if TOOL_CALLS_END in tool_call_text:
                # Parse the accumulated DSML block
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
                continue

            # EOS reached before end marker — yield buffered text as-is
            if response.finish_reason is not None:
                logger.info("DSML tool call parsing interrupted by EOS")
                yield response.model_copy(update={"text": tool_call_text})
                in_tool_call = False
                tool_call_text = ""
            continue

        # ── Detect start of tool call block ──
        accumulated += response.text

        if TOOL_CALLS_START in accumulated:
            # The start marker might be split across pending_buffer + current token
            start_idx = accumulated.index(TOOL_CALLS_START)
            # Yield any pending tokens that are purely before the marker
            pre_text = accumulated[:start_idx]
            if pre_text:
                # Flush pending buffer tokens that contributed text before the marker
                for buf_resp in pending_buffer:
                    if pre_text:
                        chunk = buf_resp.text
                        if len(chunk) <= len(pre_text):
                            yield buf_resp
                            pre_text = pre_text[len(chunk) :]
                        else:
                            yield buf_resp.model_copy(update={"text": pre_text})
                            pre_text = ""
            pending_buffer = []
            tool_call_text = accumulated[start_idx:]
            accumulated = ""

            # Check if the end marker is already present (entire tool call in one token)
            if TOOL_CALLS_END in tool_call_text:
                parsed = parse_dsml_output(tool_call_text)
                if parsed is not None:
                    logger.info(f"parsed DSML tool calls: {parsed}")
                    yield ToolCallResponse(
                        tool_calls=parsed,
                        usage=response.usage,
                        stats=response.stats,
                    )
                else:
                    logger.warning(
                        f"DSML tool call parsing failed for: {tool_call_text}"
                    )
                    yield response.model_copy(update={"text": tool_call_text})
                tool_call_text = ""
            else:
                in_tool_call = True
            continue

        # Check if accumulated text might be the start of a DSML marker
        # Buffer tokens if we see a partial match at the end
        if _could_be_dsml_prefix(accumulated):
            pending_buffer.append(response)
            continue

        # No partial match — flush all pending tokens and the current one
        for buf_resp in pending_buffer:
            yield buf_resp
        pending_buffer = []
        accumulated = ""
        yield response

    # Flush any remaining pending buffer at generator end
    for buf_resp in pending_buffer:
        yield buf_resp


def _could_be_dsml_prefix(text: str) -> bool:
    """Check if the end of text could be the start of a DSML function_calls marker.

    We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
    This allows us to buffer tokens until we can determine if a tool call is starting.
    """
    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START

    # Only check the last portion of text that could overlap with the marker
    max_check = len(TOOL_CALLS_START)
    tail = text[-max_check:] if len(text) > max_check else text

    # Check if any suffix of tail is a prefix of TOOL_CALLS_START
    for i in range(len(tail)):
        suffix = tail[i:]
        if TOOL_CALLS_START.startswith(suffix):
            return True
    return False
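The suffix-against-prefix check above is the crux of the streaming buffer. A standalone sketch of the same idea, with the marker hard-coded (using an ASCII pipe purely for illustration; the real code uses TOOL_CALLS_START with fullwidth pipes):

    # Standalone illustration of the buffering decision.
    MARKER = "<|DSML|function_calls>"

    def could_be_prefix(text: str) -> bool:
        tail = text[-len(MARKER):]
        # True if any suffix of the tail is a prefix of the marker.
        return any(MARKER.startswith(tail[i:]) for i in range(len(tail)))

    assert could_be_prefix("Sure, let me check.<|DS")  # buffer: might be a marker
    assert not could_be_prefix("x > 0 and y < 100.")   # flush: no marker prefix at the end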


def parse_thinking_models(
    responses: Generator[GenerationResponse],
    tokenizer: TokenizerWrapper,
    starts_in_thinking: bool = True,
) -> Generator[GenerationResponse]:
    """Route thinking tokens via is_thinking flag.

    Swallows think tag tokens, sets is_thinking on all others.
    Always yields tokens with finish_reason to avoid hanging the chunk stream.
    """
    For models that inject thinking tags in the prompt (like GLM-4.7),
    prepend the thinking tag to the output stream so the frontend
    can properly parse thinking content.
    """
    first = True
    in_thinking = starts_in_thinking
    for response in responses:
        if isinstance(response, ToolCallResponse):
            yield response
            continue
        if first:
            first = False
            yield response.model_copy(
                update={
                    "text": tokenizer.think_start,
                    "token": tokenizer.think_start_id,
                }
            )
        yield response

        is_think_tag = (
            tokenizer.think_end is not None and response.text == tokenizer.think_end
        ) or (
            tokenizer.think_start is not None and response.text == tokenizer.think_start
        )

        if is_think_tag:
            in_thinking = response.text != tokenizer.think_end
            # Never swallow finish_reason — the chunk stream needs it to terminate.
            if response.finish_reason is not None:
                yield response.model_copy(update={"text": "", "is_thinking": False})
            continue
        yield response.model_copy(update={"is_thinking": in_thinking})
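A minimal sketch of the reworked routing, driven by a duck-typed stub tokenizer. This assumes only the new is_thinking path (the diff above interleaves old and new lines); GenerationResponse fields follow the test fixtures later in this patch:

    # Stub with just the attributes parse_thinking_models reads (duck-typed).
    class StubTokenizer:
        think_start = "<think>"
        think_start_id = 0
        think_end = "</think>"

    def gen():
        for i, (text, fin) in enumerate(
            [("reasoning...", None), ("</think>", None), ("answer", "stop")]
        ):
            yield GenerationResponse(text=text, token=i, finish_reason=fin, usage=None)

    out = list(parse_thinking_models(gen(), StubTokenizer(), starts_in_thinking=True))
    # "reasoning..." arrives with is_thinking=True; the </think> token itself is
    # swallowed (it carries no finish_reason); "answer" follows with
    # is_thinking=False and keeps its finish_reason so the stream terminates.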


def _send_image_chunk(

@@ -100,8 +100,8 @@ class RunnerSupervisor:
        logger.info("Runner supervisor shutting down")
        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        with contextlib.suppress(ClosedResourceError):
            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._cancel_sender.close()
        self.runner_process.join(5)
        if not self.runner_process.is_alive():
@@ -180,6 +180,7 @@ class RunnerSupervisor:
            await self._check_runner(e)
        for tid in self.pending:
            self.pending[tid].set()
        self._event_sender.close()

    def __del__(self) -> None:
        if self.runner_process.is_alive():
@@ -208,10 +209,15 @@ class RunnerSupervisor:

        logger.opt(exception=e).error(f"Runner terminated ({cause})")

        await self._event_sender.send(
            RunnerStatusUpdated(
                runner_id=self.bound_instance.bound_runner_id,
                runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
        try:
            await self._event_sender.send(
                RunnerStatusUpdated(
                    runner_id=self.bound_instance.bound_runner_id,
                    runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
                )
            )
        except (ClosedResourceError, BrokenResourceError):
            logger.warning(
                "Event sender already closed, unable to report runner failure"
            )
        )
        self.shutdown()

@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],
    }

@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    # Global state shows shard is downloaded for NODE_A
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],
    }
@@ -187,9 +181,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],  # NODE_B has no downloads completed yet
    }
@@ -207,14 +199,10 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():

    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],  # NODE_B now has its shard downloaded as well
    }

src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py (new file, 967 lines)
@@ -0,0 +1,967 @@
import json
from collections.abc import Generator
from typing import Any

from exo.shared.types.worker.runner_response import (
    GenerationResponse,
    ToolCallResponse,
)
from exo.worker.engines.mlx.dsml_encoding import (
    ASSISTANT_TOKEN,
    BOS_TOKEN,
    DSML_TOKEN,
    EOS_TOKEN,
    THINKING_END,
    THINKING_START,
    TOOL_CALLS_END,
    TOOL_CALLS_START,
    USER_TOKEN,
    encode_messages,
    parse_dsml_output,
)
from exo.worker.runner.runner import parse_deepseek_v32

# ── Shared fixtures ──────────────────────────────────────────────

_WEATHER_TOOLS: list[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "The city name"},
                    "units": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "Temperature units",
                    },
                },
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Get the current time in a timezone",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {"type": "string"},
                },
                "required": ["timezone"],
            },
        },
    },
]


def _simulate_tokens(
    texts: list[str],
    finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
    """Simulate a model producing tokens from a list of text strings."""
    for i, text in enumerate(texts):
        is_last = i == len(texts) - 1
        yield GenerationResponse(
            text=text,
            token=i,
            finish_reason="stop" if (is_last and finish_on_last) else None,
            usage=None,
        )


# ── Test: Standard text response (no tool calls) ────────────────


class TestE2EStandardResponse:
    """Model generates a plain text response — no tool calling involved."""

    def test_plain_text_passthrough(self):
        """Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
        # Step 1: Encode the prompt (with tools available)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather in NYC?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Verify prompt structure
        assert BOS_TOKEN in prompt
        assert "## Tools" in prompt
        assert "get_weather" in prompt
        assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt

        # Step 2: Simulate model response — plain text tokens (no DSML)
        model_tokens = [
            "The weather",
            " in NYC",
            " is 72",
            "°F",
            " and sunny",
            ".",
        ]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        # Step 3: Verify all tokens pass through as GenerationResponse
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        assert len(gen_results) == 6
        full_text = "".join(r.text for r in gen_results)
        assert full_text == "The weather in NYC is 72°F and sunny."
        assert gen_results[-1].finish_reason == "stop"


# ── Test: Tool call response ─────────────────────────────────────


class TestE2EToolCallResponse:
    """Model generates a DSML tool call — realistic token boundaries."""

    def test_realistic_tool_call_tokens(self):
        """Simulate model generating a get_weather tool call with realistic token splits.

        Real models split DSML markers across tokens unpredictably.
        This simulates how DeepSeek V3.2 actually tokenizes DSML output.
        """
        # Step 1: Encode prompt
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather in San Francisco?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
        assert "get_weather" in prompt

        # Step 2: Simulate realistic token-by-token model output
        # The model first produces some text, then a DSML tool call block
        model_tokens = [
            "I'll check the weather for you.",
            "\n\n",
            f"<{DSML_TOKEN}",  # marker split across tokens
            "function_calls>\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">',
            "San Francisco",
            f"</{DSML_TOKEN}parameter>\n",
            f'<{DSML_TOKEN}parameter name="units" string="false">',
            '"celsius"',
            f"</{DSML_TOKEN}parameter>\n",
            f"</{DSML_TOKEN}invoke>\n",
            f"</{DSML_TOKEN}function_calls>",
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        # Step 3: Verify
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        # Should have text tokens before tool call + one ToolCallResponse
        assert len(tool_results) == 1
        assert len(tool_results[0].tool_calls) == 1

        tc = tool_results[0].tool_calls[0]
        assert tc.name == "get_weather"
        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "San Francisco"
        assert args["units"] == "celsius"

        # The text before the tool call should still be yielded
        text_before = "".join(r.text for r in gen_results if not r.is_thinking)
        assert "check the weather" in text_before

    def test_multiple_tool_calls_in_one_block(self):
        """Model generates two tool calls in a single function_calls block."""
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Weather in NYC and time in EST?"},
        ]
        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
        assert "get_weather" in prompt
        assert "get_time" in prompt

        # Simulate model output with two invocations
        model_tokens = [
            "Let me check both.\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        assert len(tool_results[0].tool_calls) == 2
        assert tool_results[0].tool_calls[0].name == "get_weather"
        assert tool_results[0].tool_calls[1].name == "get_time"

        args0 = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        args1 = json.loads(tool_results[0].tool_calls[1].arguments)  # pyright: ignore[reportAny]
        assert args0 == {"city": "NYC"}
        assert args1 == {"timezone": "EST"}


# ── Test: Multi-turn tool use flow ───────────────────────────────


class TestE2EMultiTurnToolUse:
    """Full multi-turn: user asks → model calls tool → tool result → model answers."""

    def test_encode_multi_turn_with_tool_results(self):
        """Verify the prompt for turn 2 (after tool results) is correctly encoded."""
        # Turn 1: user asks, model calls tool
        # Turn 2: tool result provided, model answers
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a weather assistant."},
            {"role": "user", "content": "What's the weather in NYC?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "NYC"}',
                        },
                    }
                ],
            },
            {"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Verify multi-turn structure
        assert BOS_TOKEN in prompt
        assert "You are a weather assistant." in prompt
        assert "## Tools" in prompt

        # The assistant's tool call should be encoded as DSML
        assert TOOL_CALLS_START in prompt
        assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
        assert EOS_TOKEN in prompt

        # The tool result should be wrapped in function_results
        assert "<function_results>" in prompt
        assert "<result>" in prompt
        assert "72" in prompt
        assert "</function_results>" in prompt

        # Now simulate model answering after seeing the tool result
        model_tokens = [
            "The current",
            " weather in NYC",
            " is 72°F",
            " and sunny.",
        ]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        full_text = "".join(r.text for r in gen_results)
        assert full_text == "The current weather in NYC is 72°F and sunny."

    def test_multi_tool_results_encoding(self):
        """Verify encoding when model called two tools and both return results."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Weather and time?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "LA"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "PST"}',
                        },
                    },
                ],
            },
            {"role": "tool", "content": "85F, clear skies"},
            {"role": "tool", "content": "3:42 PM PST"},
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Should have one function_results block with two results
        assert prompt.count("<function_results>") == 1
        assert prompt.count("</function_results>") == 1
        assert "<result>85F, clear skies</result>" in prompt
        assert "<result>3:42 PM PST</result>" in prompt


# ── Test: Thinking + tool call ───────────────────────────────────


class TestE2EThinkingAndToolCall:
    """Model uses thinking mode, reasons, then makes a tool call."""

    def test_thinking_then_tool_call(self):
        """Model thinks first, then produces a DSML tool call block."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "What's the weather?"},
        ]
        prompt = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        # Thinking mode: prompt should end with <think>
        assert prompt.endswith(THINKING_START)

        # Simulate: model outputs <think>, thinks, closes thinking, then tool call.
        # In the full pipeline, parse_thinking_models handles the case where
        # <think> is in the prompt. Here we test parse_deepseek_v32 directly,
        # which detects <think>/</think> markers in the stream.
        model_tokens = [
            THINKING_START,
            "The user wants weather",
            " information. I should use",
            " the get_weather tool.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">',
            "San Francisco",
            f"</{DSML_TOKEN}parameter>\n",
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))

        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        # Should have thinking tokens + tool call
        thinking_results = [r for r in gen_results if r.is_thinking]

        assert len(thinking_results) >= 1
        thinking_text = "".join(r.text for r in thinking_results)
        assert "get_weather tool" in thinking_text

        assert len(tool_results) == 1
        assert tool_results[0].tool_calls[0].name == "get_weather"
        args = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "San Francisco"

    def test_thinking_prompt_encoding(self):
        """Verify thinking mode affects prompt encoding correctly."""
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "Be thorough."},
            {"role": "user", "content": "What's the weather?"},
        ]

        # With thinking enabled
        prompt_think = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_think.endswith(THINKING_START)

        # With thinking disabled
        prompt_no_think = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
        )
        assert prompt_no_think.endswith(THINKING_END)

        # Both should have the same tool definitions
        assert "get_weather" in prompt_think
        assert "get_weather" in prompt_no_think


# ── Test: Round-trip encode → parse ──────────────────────────────


class TestE2ERoundTrip:
    """Verify that DSML we encode can be parsed back correctly."""

    def test_encoded_tool_call_is_parseable(self):
        """Encode an assistant tool call message, then parse the DSML output."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Weather?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Tokyo", "units": "celsius"}',
                        },
                    }
                ],
            },
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        # Extract the DSML function_calls block from the prompt
        start = prompt.index(TOOL_CALLS_START)
        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
        dsml_block = prompt[start:end]

        # Parse it back
        parsed = parse_dsml_output(dsml_block)
        assert parsed is not None
        assert len(parsed) == 1
        assert parsed[0].name == "get_weather"
        args = json.loads(parsed[0].arguments)  # pyright: ignore[reportAny]
        assert args["city"] == "Tokyo"
        assert args["units"] == "celsius"

    def test_encoded_multi_tool_call_round_trips(self):
        """Encode multiple tool calls, verify they parse back correctly."""
        messages: list[dict[str, Any]] = [
            {"role": "user", "content": "Both please"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Paris"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "CET"}',
                        },
                    },
                ],
            },
        ]

        prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)

        start = prompt.index(TOOL_CALLS_START)
        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
        dsml_block = prompt[start:end]

        parsed = parse_dsml_output(dsml_block)
        assert parsed is not None
        assert len(parsed) == 2
        assert parsed[0].name == "get_weather"
        assert parsed[1].name == "get_time"
        assert json.loads(parsed[0].arguments) == {"city": "Paris"}
        assert json.loads(parsed[1].arguments) == {"timezone": "CET"}


# ── Test: Edge cases with realistic token boundaries ─────────────


class TestE2EEdgeCases:
    """Edge cases that occur in real model inference."""

    def test_dsml_marker_split_at_fullwidth_pipe(self):
        """The fullwidth pipe character | might be its own token."""
        # This is a realistic tokenization: the DSML marker is split at the | chars
        model_tokens = [
            "Let me help.\n\n",
            "<\uff5c",  # start of |DSML|
            "DSML\uff5c",  # rest of DSML token
            "function_calls>\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        assert tool_results[0].tool_calls[0].name == "get_weather"

    def test_tool_call_with_nested_json_object(self):
        """Model passes a complex JSON object as a non-string parameter."""
        dsml_block = (
            f"{TOOL_CALLS_START}\n"
            f'<{DSML_TOKEN}invoke name="create_event">\n'
            f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
            f'<{DSML_TOKEN}parameter name="config" string="false">'
            f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
            f"</{DSML_TOKEN}parameter>\n"
            f"</{DSML_TOKEN}invoke>\n"
            f"{TOOL_CALLS_END}"
        )

        # Feed as single token (model might produce it all at once after prefill)
        results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        tc = tool_results[0].tool_calls[0]
        assert tc.name == "create_event"
        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]
        assert args["title"] == "Team Standup"
        assert args["config"]["recurring"] is True
        assert args["config"]["days"] == ["mon", "wed", "fri"]

    def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
        """Angle brackets in normal text should not trigger DSML buffering."""
        model_tokens = [
            "The formula is ",
            "<x, y>",
            " where x > 0",
            " and y < 100.",
        ]

        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 0
        full_text = "".join(r.text for r in gen_results)
        assert "formula" in full_text
        assert "<x, y>" in full_text

    def test_empty_model_response(self):
        """Model produces only EOS (empty response)."""
        model_tokens = [""]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        gen_results = [r for r in results if isinstance(r, GenerationResponse)]
        assert len(gen_results) == 1
        assert gen_results[0].text == ""
        assert gen_results[0].finish_reason == "stop"


# ── Test: Full EPDP spec round-trip ──────────────────────────────


class TestE2EFullRoundTrip:
    """Full round-trip matching the vLLM EPDP spec.

    Simulates the complete multi-turn flow:
    Turn 1: user asks → think → tool call → tool result → think → answer
    Turn 2: user asks again → old reasoning stripped → think → answer
    """

    def test_single_tool_full_flow_with_thinking(self):
        """Complete flow: user → think → tool call → tool result → think → answer.

        This is the core EPDP flow from the vLLM spec.
        """
        # ── Turn 1.1: User asks, encode prompt ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are a weather assistant."},
            {"role": "user", "content": "How's the weather in Hangzhou?"},
        ]
        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)
        assert "## Tools" in prompt_1
        assert "get_weather" in prompt_1

        # ── Turn 1.1: Model thinks, then calls tool ──
        model_tokens_1 = [
            THINKING_START,
            "The user wants to know the weather in Hangzhou.",
            " I need to use the get_weather tool.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))

        # Verify: thinking tokens + tool call
        gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
        thinking_1 = [r for r in gen_1 if r.is_thinking]

        assert len(thinking_1) >= 1
        assert "get_weather tool" in "".join(r.text for r in thinking_1)
        assert len(tool_1) == 1
        assert tool_1[0].tool_calls[0].name == "get_weather"
        tc_args = json.loads(tool_1[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        assert tc_args == {"city": "Hangzhou"}

        # ── Turn 1.2: Add assistant response + tool result to messages ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Hangzhou"}',
                        },
                    }
                ],
            }
        )
        messages.append(
            {
                "role": "tool",
                "content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
            }
        )

        # Encode prompt for turn 1.2
        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Verify: prompt has the full conversation structure
        assert TOOL_CALLS_START in prompt_2  # assistant's encoded tool call
        assert EOS_TOKEN in prompt_2  # assistant turn ends with EOS
        assert "<function_results>" in prompt_2
        assert "<result>" in prompt_2
        assert "Cloudy" in prompt_2
        assert "</function_results>" in prompt_2
        # After tool results with thinking enabled → <think> appended
        assert prompt_2.endswith(THINKING_START)
        # The assistant's reasoning_content should appear (it's after last_user_idx)
        assert "get_weather tool" in prompt_2

        # ── Turn 1.2: Model thinks about results, then answers ──
        model_tokens_2 = [
            THINKING_START,
            "The weather in Hangzhou is Cloudy, 7~13°C.",
            " I'll tell the user.",
            THINKING_END,
            "The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))

        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        thinking_2 = [r for r in gen_2 if r.is_thinking]
        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]

        assert len(tool_2) == 0  # No more tool calls
        assert len(thinking_2) >= 1
        assert "Cloudy" in "".join(r.text for r in thinking_2)
        assert len(non_thinking_2) >= 1
        final_text = "".join(r.text for r in non_thinking_2)
        assert "7°C" in final_text
        assert "13°C" in final_text

    def test_multi_tool_full_flow(self):
        """Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
        # ── Initial prompt ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You help with weather and time."},
            {"role": "user", "content": "Weather in NYC and time in EST?"},
        ]
        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)

        # ── Model thinks, calls both tools ──
        model_tokens_1 = [
            THINKING_START,
            "Two requests: weather and time. I'll call both.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]

        assert len(tool_1) == 1
        assert len(tool_1[0].tool_calls) == 2
        assert tool_1[0].tool_calls[0].name == "get_weather"
        assert tool_1[0].tool_calls[1].name == "get_time"

        # ── Add assistant + both tool results ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "Two requests: weather and time. I'll call both.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "NYC"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "EST"}',
                        },
                    },
                ],
            }
        )
        messages.append({"role": "tool", "content": "72°F, sunny"})
        messages.append({"role": "tool", "content": "2:30 PM EST"})

        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Verify multi-tool result encoding
        # Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
        assert prompt_2.count("<function_results>") == 2
        assert prompt_2.count("</function_results>") == 2
        assert "<result>72°F, sunny</result>" in prompt_2
        assert "<result>2:30 PM EST</result>" in prompt_2
        assert prompt_2.endswith(THINKING_START)

        # ── Model thinks about results, answers ──
        model_tokens_2 = [
            THINKING_START,
            "Got both results. Weather is 72°F sunny, time is 2:30 PM.",
            THINKING_END,
            "In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))

        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]

        assert len(tool_2) == 0
        final_text = "".join(r.text for r in non_thinking_2)
        assert "72°F" in final_text
        assert "2:30 PM" in final_text

    def test_two_user_turns_reasoning_stripped(self):
        """Turn 2: old reasoning_content is stripped from history.

        Per the vLLM spec, clear_reasoning_content is called between user turns
        to save bandwidth. Our _drop_old_thinking handles this.
        """
        # Full turn 1 conversation (already completed)
        messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in Hangzhou?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need to call get_weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "Cloudy 7~13°C"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
|
||||
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
|
||||
},
|
||||
# Turn 2: user asks again
|
||||
{"role": "user", "content": "What about Beijing?"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Old reasoning_content from turn 1 assistants should be STRIPPED
|
||||
# (they're before the last user message at index 5)
|
||||
assert "I need to call get_weather" not in prompt
|
||||
assert "tool returned cloudy" not in prompt
|
||||
|
||||
# But the assistant's content and tool calls should still be there
|
||||
assert "cloudy, 7-13°C" in prompt
|
||||
assert TOOL_CALLS_START in prompt
|
||||
|
||||
# Prompt ends with <think> for the new turn
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# ── Turn 2: Model thinks, calls tool for Beijing ──
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"Now the user wants Beijing weather.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args == {"city": "Beijing"}
|
||||
|
||||
def test_chained_tool_calls_loop(self):
|
||||
"""Model calls tool, gets result, calls another tool, gets result, answers.
|
||||
|
||||
This simulates the inner while loop from the vLLM spec where the model
|
||||
may need multiple sub-turns of tool calling before it has enough info.
|
||||
"""
|
||||
# ── Sub-turn 1: user asks, model calls get_time ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
|
||||
]
|
||||
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
|
||||
# Model first calls get_time to figure out the date
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"I need the current date first to calculate tomorrow.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_time"
|
||||
|
||||
# ── Sub-turn 2: add tool result, model calls get_weather ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need the current date first to calculate tomorrow.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "Asia/Shanghai"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
|
||||
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
|
||||
# Model now knows the date, calls get_weather
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
|
||||
" Now I can check weather for Hangzhou.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_2) == 1
|
||||
assert tool_2[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
# ── Sub-turn 3: add weather result, model answers ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
|
||||
|
||||
prompt_3 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Should have both function_results blocks (one per tool round)
|
||||
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
|
||||
assert prompt_3.count("<function_results>") == 3
|
||||
assert prompt_3.count("</function_results>") == 3
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
|
||||
assert "<result>Sunny, 5~12°C</result>" in prompt_3
|
||||
assert prompt_3.endswith(THINKING_START)
|
||||
|
||||
# Model finally answers
|
||||
model_tokens_3 = [
|
||||
THINKING_START,
|
||||
"I have the weather for tomorrow in Hangzhou.",
|
||||
THINKING_END,
|
||||
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
|
||||
]
|
||||
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
|
||||
|
||||
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
|
||||
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
|
||||
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
|
||||
|
||||
assert len(tool_3) == 0 # No more tool calls — loop ends
|
||||
final_text = "".join(r.text for r in non_thinking_3)
|
||||
assert "sunny" in final_text.lower()
|
||||
assert "5°C" in final_text
|
||||
assert "12°C" in final_text
|
||||
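The tests above lean on two helpers defined elsewhere in the file. A minimal sketch of plausible shapes for them, assuming the parser consumes an iterator of already-decoded text chunks and that reasoning is dropped only for messages before the last user turn; the real helpers may differ:

from collections.abc import Iterator
from typing import Any


def _simulate_tokens(chunks: list[str]) -> Iterator[str]:
    # Replay pre-decoded text chunks as if they were streamed from the model.
    yield from chunks


def _drop_old_thinking(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Strip reasoning_content from messages before the last user turn,
    # mirroring vLLM's clear_reasoning_content bandwidth optimisation.
    last_user = max(
        (i for i, m in enumerate(messages) if m.get("role") == "user"), default=-1
    )
    return [
        {k: v for k, v in m.items() if i >= last_user or k != "reasoning_content"}
        for i, m in enumerate(messages)
    ]

With that shape, the prompt assertions above follow directly: only the assistant turn after the final user message keeps its reasoning_content.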
@@ -148,6 +148,7 @@ class MockTokenizer:
    tool_call_start = None
    tool_call_end = None
    has_tool_calling = False
    has_thinking = False


class MockGroup:
@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
    def test_thinking_then_tool_call(self):
        results = _collect(THINKING_THEN_TOOL_TOKENS)

        # Should have thinking tags + content + tool call
        text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
        combined = "".join(text_parts)
        assert "<think>" in combined
        assert "</think>" in combined
        assert "Let me think about this." in combined
        # Thinking tokens should have is_thinking=True and no <think> tags
        thinking_responses = [
            r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
        ]
        thinking_text = "".join(r.text for r in thinking_responses)
        assert "Let me think about this." in thinking_text
        assert "<think>" not in thinking_text
        assert "</think>" not in thinking_text

        # Non-thinking tokens should have is_thinking=False
        non_thinking = [
            r
            for r in results
            if isinstance(r, GenerationResponse) and not r.is_thinking
        ]
        non_thinking_text = "".join(r.text for r in non_thinking)
        assert "<think>" not in non_thinking_text

        # And the tool call
        tc = _get_tool_call(results)
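The is_thinking bookkeeping these assertions pin down is easiest to see in a toy splitter. A minimal sketch, assuming the <think> and </think> delimiters arrive as standalone chunks; this is an illustration, not the repo's actual parser:

def split_thinking(chunks: list[str]) -> tuple[str, str]:
    # Partition streamed text into (thinking, visible) strings, dropping the
    # <think>/</think> delimiter chunks themselves.
    thinking: list[str] = []
    visible: list[str] = []
    in_think = False
    for chunk in chunks:
        if chunk == "<think>":
            in_think = True
        elif chunk == "</think>":
            in_think = False
        else:
            (thinking if in_think else visible).append(chunk)
    return "".join(thinking), "".join(visible)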
8
tmp/config_examples/claude_code.sh
Executable file
@@ -0,0 +1,8 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude
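Before launching, it is worth checking that an exo node is actually listening on the port the script targets. A minimal reachability probe, assuming only that something answers HTTP on localhost:52415; the endpoint path and response body are not guaranteed by the script:

import urllib.error
import urllib.request

try:
    with urllib.request.urlopen("http://localhost:52415/", timeout=5) as resp:
        print("exo API reachable, HTTP", resp.status)
except urllib.error.HTTPError as exc:
    # An HTTP error status still proves something is listening on the port.
    print("exo API reachable, HTTP", exc.code)
except OSError as exc:
    print("exo API not reachable:", exc)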
24
uv.lock
generated
@@ -193,20 +193,14 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
wheels = [
    { url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" },
    { url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" },
    { url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" },
    { url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" },
    { url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" },
    { url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" },
    { url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" },
    { url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" },
    { url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" },
    { url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" },
    { url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" },
    { url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" },
    { url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" },
    { url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" },
    { url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" },
    { url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" },
    { url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" },
    { url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" },
    { url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" },
@@ -312,10 +306,8 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807, upload-time = "2025-10-15T23:16:56.414Z" },
    { url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615, upload-time = "2025-10-15T23:16:58.442Z" },
    { url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800, upload-time = "2025-10-15T23:17:00.378Z" },
    { url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707, upload-time = "2025-10-15T23:17:01.98Z" },
    { url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541, upload-time = "2025-10-15T23:17:04.078Z" },
    { url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464, upload-time = "2025-10-15T23:17:05.483Z" },
    { url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838, upload-time = "2025-10-15T23:17:07.425Z" },
    { url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596, upload-time = "2025-10-15T23:17:09.343Z" },
    { url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782, upload-time = "2025-10-15T23:17:11.22Z" },
    { url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381, upload-time = "2025-10-15T23:17:12.829Z" },
@@ -323,10 +315,8 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078, upload-time = "2025-10-15T23:17:23.042Z" },
    { url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460, upload-time = "2025-10-15T23:17:24.885Z" },
    { url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237, upload-time = "2025-10-15T23:17:26.449Z" },
    { url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344, upload-time = "2025-10-15T23:17:28.06Z" },
    { url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564, upload-time = "2025-10-15T23:17:29.665Z" },
    { url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415, upload-time = "2025-10-15T23:17:31.686Z" },
    { url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457, upload-time = "2025-10-15T23:17:33.478Z" },
    { url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074, upload-time = "2025-10-15T23:17:35.158Z" },
    { url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569, upload-time = "2025-10-15T23:17:37.188Z" },
    { url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941, upload-time = "2025-10-15T23:17:39.236Z" },
@@ -334,10 +324,8 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029, upload-time = "2025-10-15T23:17:49.837Z" },
    { url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222, upload-time = "2025-10-15T23:17:51.357Z" },
    { url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280, upload-time = "2025-10-15T23:17:52.964Z" },
    { url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958, upload-time = "2025-10-15T23:17:54.965Z" },
    { url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714, upload-time = "2025-10-15T23:17:56.754Z" },
    { url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970, upload-time = "2025-10-15T23:17:58.588Z" },
    { url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236, upload-time = "2025-10-15T23:18:00.897Z" },
    { url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642, upload-time = "2025-10-15T23:18:02.749Z" },
    { url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126, upload-time = "2025-10-15T23:18:04.85Z" },
    { url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" },
@@ -378,7 +366,7 @@ dependencies = [
    { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", version = "0.30.7.dev20260220+bdfe78f6", source = { git = "https://github.com/JakeHillion/mlx.git?branch=test-mlx-lazy-import#bdfe78f6e1fccb7cb3dfd049eb38f7c611e5f323" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,7 +404,7 @@ requires-dist = [
    { name = "hypercorn", specifier = ">=0.18.0" },
    { name = "loguru", specifier = ">=0.7.3" },
    { name = "mflux", specifier = "==0.15.5" },
    { name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
    { name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/JakeHillion/mlx.git?branch=test-mlx-lazy-import" },
    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
    { name = "mlx-lm", specifier = "==0.30.7" },
    { name = "msgspec", specifier = ">=0.19.0" },
@@ -1023,7 +1011,7 @@ dependencies = [
    { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", version = "0.30.7.dev20260220+bdfe78f6", source = { git = "https://github.com/JakeHillion/mlx.git?branch=test-mlx-lazy-import#bdfe78f6e1fccb7cb3dfd049eb38f7c611e5f323" }, marker = "sys_platform == 'darwin'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1070,8 +1058,8 @@ cuda13 = [

[[package]]
name = "mlx"
version = "0.30.7.dev20260218+14841977"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }
version = "0.30.7.dev20260220+bdfe78f6"
source = { git = "https://github.com/JakeHillion/mlx.git?branch=test-mlx-lazy-import#bdfe78f6e1fccb7cb3dfd049eb38f7c611e5f323" }
resolution-markers = [
    "sys_platform == 'darwin'",
]
@@ -1106,7 +1094,7 @@ version = "0.30.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", version = "0.30.7.dev20260220+bdfe78f6", source = { git = "https://github.com/JakeHillion/mlx.git?branch=test-mlx-lazy-import#bdfe78f6e1fccb7cb3dfd049eb38f7c611e5f323" }, marker = "sys_platform == 'darwin'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },