Compare commits

..

3 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
905fd5e900 fix prompt tokens reporting mistake 2026-02-12 19:00:07 +00:00
rltakashige
1f19619e1e Merge branch 'main' into leo/fix-usage-metrics 2026-02-12 18:28:46 +00:00
Ryuichi Leo Takashige
c4da2bc211 Pass usage and generation stats through all adapters correctly 2026-02-12 18:27:33 +00:00
17 changed files with 52 additions and 596 deletions

View File

@@ -1,15 +0,0 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

View File

@@ -1,29 +0,0 @@
name: e2e-tests
on:
push:
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Run E2E tests
run: python3 e2e/run_all.py

View File

@@ -1,53 +0,0 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
CMD [".venv/bin/exo", "-v"]

View File

@@ -1,182 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

View File

@@ -1,18 +0,0 @@
services:
exo-node-1:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with '# slow' on the first line of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,8 +0,0 @@
{
"model": "mlx-community/Qwen3-0.6B-4bit",
"seed": 42,
"temperature": 0,
"prompt": "What is 2+2? Reply with just the number.",
"max_tokens": 32,
"content": "<think>\nOkay, so I need to figure out what 2+2 is. Let me think. Well, if you add 2 and 2 together"
}

View File

@@ -1,22 +0,0 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,82 +0,0 @@
"""Test: Deterministic inference output (snapshot test).
slow
Sends a chat completion request with a fixed seed and temperature=0,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
SNAPSHOT_FILE = Path(__file__).parent / "snapshots" / "inference.json"
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Launch the model instance (triggers download + placement)
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, temperature=0)...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
# Load or create snapshot
if SNAPSHOT_FILE.exists():
snapshot = json.loads(SNAPSHOT_FILE.read_text())
expected = snapshot["content"]
assert content == expected, (
f"Snapshot mismatch!\n"
f" Expected: {expected!r}\n"
f" Got: {content!r}\n"
f" Delete {SNAPSHOT_FILE} to regenerate."
)
print(" Output matches snapshot")
else:
SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True)
SNAPSHOT_FILE.write_text(
json.dumps(
{
"model": MODEL,
"seed": SEED,
"temperature": 0,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
"content": content,
},
indent=2,
)
+ "\n"
)
print(f" Snapshot created: {SNAPSHOT_FILE}")
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,47 +0,0 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -17,6 +17,7 @@ from exo.shared.types.api import (
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -125,6 +126,8 @@ async def generate_chat_stream(
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
@@ -138,6 +141,8 @@ async def generate_chat_stream(
yield "data: [DONE]\n\n"
return
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
@@ -161,12 +166,15 @@ async def generate_chat_stream(
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
@@ -184,6 +192,7 @@ async def collect_chat_response(
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
@@ -193,6 +202,8 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
@@ -241,4 +252,5 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
usage=last_usage,
)

View File

@@ -4,7 +4,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
@@ -166,7 +166,7 @@ async def collect_claude_response(
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -174,6 +174,8 @@ async def collect_claude_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
@@ -183,12 +185,10 @@ async def collect_claude_response(
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -208,9 +208,9 @@ async def collect_claude_response(
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
# Use actual usage data if available
input_tokens = last_usage.prompt_tokens if last_usage else 0
output_tokens = last_usage.completion_tokens if last_usage else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
@@ -249,7 +249,7 @@ async def generate_claude_stream(
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
@@ -257,8 +257,9 @@ async def generate_claude_stream(
# Close text block and bail
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
@@ -290,7 +291,6 @@ async def generate_claude_stream(
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
@@ -302,9 +302,9 @@ async def generate_claude_stream(
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# Use actual token count from usage if available
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)

View File

@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
@@ -127,7 +128,7 @@ async def collect_responses_response(
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -135,6 +136,8 @@ async def collect_responses_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
@@ -145,22 +148,20 @@ async def collect_responses_response(
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
output: list[ResponseItem] = [
@@ -235,15 +236,16 @@ async def generate_responses_stream(
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{tool.id}"
call_id = f"call_{tool.id}"
@@ -302,7 +304,6 @@ async def generate_responses_stream(
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
@@ -346,13 +347,13 @@ async def generate_responses_stream(
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
# response.completed

View File

@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -393,10 +393,11 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
total_prompt_tokens = len(all_prompt_tokens)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=total_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
total_tokens=total_prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),

View File

@@ -364,6 +364,7 @@ def main(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
stats=response.stats,
),
)
)
@@ -764,7 +765,9 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
except (
json.JSONDecodeError,
@@ -795,7 +798,8 @@ def parse_tool_calls(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=None,
usage=response.usage,
stats=response.stats,
)
continue
# fallthrough