fmt: format all python/rust/nix files

Jake Hillion
2025-12-05 16:58:55 +00:00
committed by GitHub
parent 7312a7e000
commit 5629983809
9 changed files with 460 additions and 299 deletions
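The Python diffs below are consistent with black/ruff-style defaults (an assumption; the commit message does not name the formatter): double-quoted strings, 88-column wrapping with trailing commas, ** hugged only when both operands are simple, and a space before the colon in slices whose bound is an expression. A standalone sketch in that style, mirroring the TTFT statistics code further down; the function name is illustrative, not from the repo.

# sketch of the applied style; name and signature are illustrative only
def summarize_ttft(ttft_values: list[float]) -> tuple[float | None, float | None]:
    avg = sum(ttft_values) / len(ttft_values) if ttft_values else None
    if avg is not None and len(ttft_values) > 1:
        variance = sum((x - avg) ** 2 for x in ttft_values) / len(ttft_values)
        std = variance**0.5  # both operands are simple, so the ** is hugged
    else:
        std = None
    return avg, std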

View File

@@ -9,6 +9,7 @@ Requests are fire-and-forget, allowing overlapping execution.
Simple benchmark (1 iteration): --config .github/configs/bench_simple.yaml
Complex benchmark (multiple stages): --config .github/configs/bench_config.yaml
"""
# pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
from __future__ import annotations
@@ -34,7 +35,11 @@ def _format_http_error(error: Exception) -> str:
except Exception:
body = "<unable to read body>"
headers_str = "\n".join(f" {k}: {v}" for k, v in error.headers.items()) if error.headers else "<no headers>"
headers_str = (
"\n".join(f" {k}: {v}" for k, v in error.headers.items())
if error.headers
else "<no headers>"
)
return (
f"HTTP {error.code} {error.reason}\n"
@@ -48,7 +53,9 @@ def _format_http_error(error: Exception) -> str:
return str(error)
def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]:
def _http_request(
url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None
) -> dict[str, Any]:
headers = {"Content-Type": "application/json"}
payload: bytes | None = None
if data is not None:
@@ -67,14 +74,21 @@ def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | No
raise
async def _http_request_async(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]:
async def _http_request_async(
url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None
) -> dict[str, Any]:
"""Async version that runs in executor to not block event loop."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, lambda: _http_request(url, method=method, data=data))
return await loop.run_in_executor(
None, lambda: _http_request(url, method=method, data=data)
)
async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300) -> list[tuple[str, float]]:
async def _http_stream_async(
url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300
) -> list[tuple[str, float]]:
"""Async streaming request. Returns list of (line, timestamp) tuples."""
def _stream() -> list[tuple[str, float]]:
headers = {"Content-Type": "application/json"}
payload = json.dumps(data).encode("utf-8")
@@ -92,6 +106,7 @@ async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[st
error_details = _format_http_error(e)
print(f"HTTP request failed:\n{error_details}")
raise
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _stream)
@@ -196,7 +211,9 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot:
# Extract actual task from wrapper (e.g., {"ChatCompletion": {...}})
if len(task_wrapper) != 1:
print(f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}")
print(
f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}"
)
tasks_skipped += 1
continue
@@ -223,7 +240,9 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot:
tasks_matched += 1
if tasks_skipped > 0:
print(f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)")
print(
f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)"
)
# Build snapshots for each instance (assign to primary node - first in sorted order)
for instance_id, counts in instance_task_counts.items():
@@ -310,7 +329,9 @@ def count_instances_by_model(state: Mapping[str, Any], model_id: str) -> int:
return count
def get_all_instance_ids_for_model(state: Mapping[str, Any], model_id: str) -> list[str]:
def get_all_instance_ids_for_model(
state: Mapping[str, Any], model_id: str
) -> list[str]:
"""Get all instance IDs for a given model_id."""
instances: Mapping[str, Any] = state.get("instances", {})
instance_ids = []
@@ -358,7 +379,9 @@ def count_ready_instances_by_model(state: Mapping[str, Any], model_id: str) -> i
return ready_count
def get_runner_ids_for_instance(state: Mapping[str, Any], instance_id: str) -> list[str]:
def get_runner_ids_for_instance(
state: Mapping[str, Any], instance_id: str
) -> list[str]:
instances: Mapping[str, Any] = state.get("instances", {})
instance_wrapped = instances.get(instance_id, {})
@@ -387,28 +410,40 @@ def get_runner_status_kind(state: Mapping[str, Any], runner_id: str) -> str | No
return None
async def wait_for_topology_ready(api_base: str, expected_nodes: int, timeout_s: int) -> None:
async def wait_for_topology_ready(
api_base: str, expected_nodes: int, timeout_s: int
) -> None:
"""Wait for all expected nodes to appear in the topology."""
print(f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)...")
print(
f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)..."
)
start = time.monotonic()
while True:
state = fetch_state(api_base)
node_count = get_topology_node_count(state)
elapsed = time.monotonic() - start
print(f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)")
print(
f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)"
)
if node_count >= expected_nodes:
print(f"All {expected_nodes} node(s) are in topology!")
return
if elapsed > timeout_s:
raise TimeoutError(f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}")
raise TimeoutError(
f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}"
)
await asyncio.sleep(2)
async def wait_for_instances_ready(api_base: str, model_id: str, expected_count: int, timeout_s: int) -> list[str]:
async def wait_for_instances_ready(
api_base: str, model_id: str, expected_count: int, timeout_s: int
) -> list[str]:
"""Wait for a specific count of instances for a model to be fully ready."""
print(f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)...")
print(
f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)..."
)
start = time.monotonic()
while True:
state = fetch_state(api_base)
@@ -417,11 +452,15 @@ async def wait_for_instances_ready(api_base: str, model_id: str, expected_count:
ready_count = count_ready_instances_by_model(state, model_id)
elapsed = time.monotonic() - start
print(f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)")
print(
f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)"
)
if ready_count >= expected_count:
instance_ids = get_all_instance_ids_for_model(state, model_id)
print(f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}")
print(
f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}"
)
return instance_ids
if elapsed > timeout_s:
@@ -452,9 +491,9 @@ async def wait_for_tasks_drained(api_base: str, timeout_s: int = 600) -> None:
Tasks are deleted from state when complete, so we wait until there are no
pending or running tasks remaining.
"""
print(f"\n{'='*80}")
print(f"\n{'=' * 80}")
print(f"⏳ WAITING FOR ALL TASKS TO DRAIN")
print(f"{'='*80}")
print(f"{'=' * 80}")
start = time.monotonic()
while True:
@@ -472,18 +511,26 @@ async def wait_for_tasks_drained(api_base: str, timeout_s: int = 600) -> None:
print(f"✅ All tasks drained after {elapsed:.1f}s")
return
print(f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)")
print(
f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)"
)
# Print per-node breakdown if there are active tasks
if snapshot.node_tasks:
for node_snapshot in snapshot.node_tasks:
if node_snapshot.total_active_tasks > 0:
node_short = node_snapshot.node_id[-4:]
print(f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending")
print(
f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending"
)
if elapsed > timeout_s:
print(f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s")
print(f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)")
print(
f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s"
)
print(
f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)"
)
return
await asyncio.sleep(2)
@@ -545,6 +592,7 @@ class StageResult:
@dataclass(frozen=True)
class MemorySnapshot:
"""Memory snapshot for a node at a point in time."""
ram_total_bytes: int
ram_available_bytes: int
ram_used_bytes: int
@@ -560,6 +608,7 @@ class InstanceTaskSnapshot:
Note: Tasks are deleted from state when complete, so we only track active tasks.
total_active_tasks = pending + running.
"""
instance_id: str
node_id: str
pending_tasks: int
@@ -574,6 +623,7 @@ class NodeTaskSnapshot:
Note: Tasks are deleted from state when complete, so we only track active tasks.
total_active_tasks = pending + running across all instances on this node.
"""
node_id: str
pending_tasks: int
running_tasks: int
@@ -584,6 +634,7 @@ class NodeTaskSnapshot:
@dataclass(frozen=True)
class MetricsSnapshot:
"""System metrics snapshot at a point in time."""
timestamp: float
node_memory: dict[str, MemorySnapshot]
instance_tasks: list[InstanceTaskSnapshot]
@@ -622,7 +673,7 @@ async def run_single_request(
for line, timestamp in lines:
if not line.startswith("data:"):
continue
payload = line[len("data:"):].strip()
payload = line[len("data:") :].strip()
if payload == "[DONE]":
got_done = True
break
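Aside: the reformatted slice above, line[len("data:") :], lives in the SSE parsing loop. A minimal sketch of that pattern follows, assuming OpenAI-style chat-completion chunks; the sample lines are hypothetical, not captured traffic.

import json

sample_lines = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    "data: [DONE]",
]
tokens = 0
got_done = False
for line in sample_lines:
    if not line.startswith("data:"):
        continue
    payload = line[len("data:") :].strip()
    if payload == "[DONE]":
        got_done = True
        break
    chunk = json.loads(payload)
    # count a token whenever a chunk carries non-empty delta content (assumed schema)
    if chunk["choices"][0]["delta"].get("content"):
        tokens += 1
print(tokens, got_done)  # 1 True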
@@ -655,7 +706,9 @@ async def run_single_request(
# Request is only successful if we got at least one token AND a [DONE] marker
if tokens == 0:
print(f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s")
print(
f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s"
)
return RequestResult(
request_id=request_id,
success=False,
@@ -665,11 +718,13 @@ async def run_single_request(
completed_at=completed_at,
time_to_first_token_s=time_to_first_token,
decode_tps=decode_tps,
error="No tokens generated"
error="No tokens generated",
)
if not got_done:
print(f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s")
print(
f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s"
)
return RequestResult(
request_id=request_id,
success=False,
@@ -679,12 +734,16 @@ async def run_single_request(
completed_at=completed_at,
time_to_first_token_s=time_to_first_token,
decode_tps=decode_tps,
error="Incomplete response (no [DONE] marker)"
error="Incomplete response (no [DONE] marker)",
)
ttft_str = f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A"
ttft_str = (
f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A"
)
tps_str = f"{decode_tps:.1f} t/s" if decode_tps is not None else "N/A"
print(f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})")
print(
f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})"
)
return RequestResult(
request_id=request_id,
success=True,
@@ -693,7 +752,7 @@ async def run_single_request(
started_at=started_at,
completed_at=completed_at,
time_to_first_token_s=time_to_first_token,
decode_tps=decode_tps
decode_tps=decode_tps,
)
except Exception as e:
@@ -710,7 +769,7 @@ async def run_single_request(
completed_at=completed_at,
time_to_first_token_s=None,
decode_tps=None,
error=error_details
error=error_details,
)
@@ -721,9 +780,9 @@ async def monitor_metrics(
interval_seconds: float = 5.0,
) -> None:
"""Background task that collects metrics snapshots every interval_seconds."""
print(f"\n{'='*80}")
print(f"\n{'=' * 80}")
print(f"🔍 METRICS MONITORING STARTED (polling every {interval_seconds}s)")
print(f"{'='*80}\n")
print(f"{'=' * 80}\n")
snapshot_count = 0
while not stop_event.is_set():
@@ -743,17 +802,22 @@ async def monitor_metrics(
total_active = sum(node.total_active_tasks for node in snapshot.node_tasks)
# Print detailed breakdown
print(f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)")
print(
f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)"
)
# Print per-node breakdown (only if there are nodes)
if snapshot.node_tasks:
for node_snapshot in snapshot.node_tasks:
node_short = node_snapshot.node_id[-4:]
print(f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances")
print(
f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances"
)
except Exception as e:
print(f"[METRICS] Error collecting snapshot: {e}")
import traceback
traceback.print_exc()
# Wait for interval or until stopped
@@ -788,7 +852,9 @@ async def run_stage(
# Sequential execution: wait for each request to complete before starting next
print("\nRunning requests sequentially (no overlap)...")
for i in range(stage.iterations):
result = await run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1)
result = await run_single_request(
api_base, model_id, prompt, stage.generation_length, i + 1
)
results.append(result)
# Wait before starting next request (except after last one)
@@ -802,7 +868,9 @@ async def run_stage(
# Fire off requests with delays between them
for i in range(stage.iterations):
task = asyncio.create_task(
run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1)
run_single_request(
api_base, model_id, prompt, stage.generation_length, i + 1
)
)
tasks.append(task)
@@ -833,29 +901,47 @@ async def run_stage(
successful_results = [r for r in results if r.success]
# Skip first iteration if there are more than 1 iterations (warmup)
results_for_stats = successful_results[1:] if len(successful_results) > 1 else successful_results
results_for_stats = (
successful_results[1:] if len(successful_results) > 1 else successful_results
)
# TTFT statistics
ttft_values = [r.time_to_first_token_s for r in results_for_stats if r.time_to_first_token_s is not None]
ttft_values = [
r.time_to_first_token_s
for r in results_for_stats
if r.time_to_first_token_s is not None
]
avg_ttft = sum(ttft_values) / len(ttft_values) if ttft_values else None
if avg_ttft is not None and len(ttft_values) > 1:
variance_ttft = sum((x - avg_ttft) ** 2 for x in ttft_values) / len(ttft_values)
std_ttft = variance_ttft ** 0.5
std_ttft = variance_ttft**0.5
else:
std_ttft = None
# Decode TPS and ms per token statistics
decode_tps_values = [r.decode_tps for r in results_for_stats if r.decode_tps is not None]
avg_decode_tps = sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None
decode_tps_values = [
r.decode_tps for r in results_for_stats if r.decode_tps is not None
]
avg_decode_tps = (
sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None
)
# Convert to ms per token
ms_per_token_values = [1000.0 / tps for tps in decode_tps_values] if decode_tps_values else []
avg_ms_per_token = sum(ms_per_token_values) / len(ms_per_token_values) if ms_per_token_values else None
ms_per_token_values = (
[1000.0 / tps for tps in decode_tps_values] if decode_tps_values else []
)
avg_ms_per_token = (
sum(ms_per_token_values) / len(ms_per_token_values)
if ms_per_token_values
else None
)
if avg_ms_per_token is not None and len(ms_per_token_values) > 1:
variance_ms_per_token = sum((x - avg_ms_per_token) ** 2 for x in ms_per_token_values) / len(ms_per_token_values)
std_ms_per_token = variance_ms_per_token ** 0.5
variance_ms_per_token = sum(
(x - avg_ms_per_token) ** 2 for x in ms_per_token_values
) / len(ms_per_token_values)
std_ms_per_token = variance_ms_per_token**0.5
else:
std_ms_per_token = None
@@ -920,8 +1006,12 @@ async def run_benchmark(
print(f"Configuration File: {config_path}")
print(f"Model IDs: {model_ids}")
print(f"Instance Count: {len(model_ids)}")
print(f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}")
print(f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}")
print(
f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}"
)
print(
f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}"
)
print(f"No Overlap: {no_overlap}")
print(f"Stages: {len(stages)}")
print(f"Expected Nodes: {expected_nodes}")
@@ -930,15 +1020,20 @@ async def run_benchmark(
try:
# Wait for all nodes to join the topology first
await wait_for_topology_ready(api_base, expected_nodes, timeout_s=timeout_seconds)
await wait_for_topology_ready(
api_base, expected_nodes, timeout_s=timeout_seconds
)
# Add 30 second delay to allow topology to stabilize before creating instances
print(f"\nWaiting 30 seconds for topology to stabilize before creating instances...")
print(
f"\nWaiting 30 seconds for topology to stabilize before creating instances..."
)
await asyncio.sleep(30)
print("Proceeding with instance creation\n")
# Count how many instances we need for each unique model_id
from collections import Counter
model_counts = Counter(model_ids)
print(f"\nTarget instance counts by model:")
@@ -958,8 +1053,12 @@ async def run_benchmark(
target_count = current_ready + 1
print("=" * 80)
print(f"[PRIMARY] Creating instance {idx+1}/{len(model_ids)} for model: {model_id}")
print(f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}")
print(
f"[PRIMARY] Creating instance {idx + 1}/{len(model_ids)} for model: {model_id}"
)
print(
f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}"
)
# Build instance creation request data
instance_data: dict[str, Any] = {"model_id": model_id}
@@ -969,21 +1068,23 @@ async def run_benchmark(
instance_data["instance_meta"] = instance_meta
response = await _http_request_async(
f"{api_base}/instance",
method="POST",
data=instance_data
f"{api_base}/instance", method="POST", data=instance_data
)
print(f"[PRIMARY] Instance creation response: {response}")
# Wait for one more instance of this model to be ready
await wait_for_instances_ready(api_base, model_id, target_count, timeout_s=timeout_seconds)
print(f"[PRIMARY] Instance {idx+1}/{len(model_ids)} is ready")
await wait_for_instances_ready(
api_base, model_id, target_count, timeout_s=timeout_seconds
)
print(f"[PRIMARY] Instance {idx + 1}/{len(model_ids)} is ready")
print("=" * 80)
else:
# Secondary: wait for expected counts of each model to be ready
print("[SECONDARY] Waiting for all instances to be created and ready...")
for model_id, expected_count in model_counts.items():
await wait_for_instances_ready(api_base, model_id, expected_count, timeout_s=timeout_seconds)
await wait_for_instances_ready(
api_base, model_id, expected_count, timeout_s=timeout_seconds
)
# Collect all instance IDs for all models
state = fetch_state(api_base)
@@ -997,7 +1098,9 @@ async def run_benchmark(
runner_ids = get_runner_ids_for_instance(state, instance_id)
total_runners += len(runner_ids)
print(f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!")
print(
f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!"
)
print(f"Instance IDs: {all_instance_ids}")
if is_primary:
@@ -1013,12 +1116,16 @@ async def run_benchmark(
metrics_snapshots: list[MetricsSnapshot] = []
stop_monitoring = asyncio.Event()
monitoring_task = asyncio.create_task(
monitor_metrics(api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5)
monitor_metrics(
api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5
)
)
stage_results: list[StageResult] = []
for stage in stages:
result = await run_stage(api_base, benchmark_model_id, stage, no_overlap=no_overlap)
result = await run_stage(
api_base, benchmark_model_id, stage, no_overlap=no_overlap
)
stage_results.append(result)
# Stop metrics monitoring
@@ -1046,12 +1153,18 @@ async def run_benchmark(
print(f" Avg Time/Request: {result.avg_time_per_request:.2f}s")
if result.avg_time_to_first_token is not None:
if result.std_time_to_first_token is not None:
print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s")
print(
f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s"
)
else:
print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s")
print(
f" Avg TTFT: {result.avg_time_to_first_token:.3f}s"
)
if result.avg_ms_per_token is not None:
if result.std_ms_per_token is not None:
print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms")
print(
f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms"
)
else:
print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms")
if result.avg_decode_tps is not None:
@@ -1100,16 +1213,33 @@ async def run_benchmark(
"failed_requests": r.failed_requests,
"success_rate": round(r.success_rate, 4),
"total_tokens": r.total_tokens,
"avg_tokens_per_request": round(r.avg_tokens_per_request, 2),
"avg_tokens_per_request": round(
r.avg_tokens_per_request, 2
),
"avg_time_per_request": round(r.avg_time_per_request, 3),
"avg_time_to_first_token": round(r.avg_time_to_first_token, 3) if r.avg_time_to_first_token is not None else None,
"std_time_to_first_token": round(r.std_time_to_first_token, 3) if r.std_time_to_first_token is not None else None,
"avg_decode_tps": round(r.avg_decode_tps, 2) if r.avg_decode_tps is not None else None,
"avg_ms_per_token": round(r.avg_ms_per_token, 2) if r.avg_ms_per_token is not None else None,
"std_ms_per_token": round(r.std_ms_per_token, 2) if r.std_ms_per_token is not None else None,
"avg_time_to_first_token": round(
r.avg_time_to_first_token, 3
)
if r.avg_time_to_first_token is not None
else None,
"std_time_to_first_token": round(
r.std_time_to_first_token, 3
)
if r.std_time_to_first_token is not None
else None,
"avg_decode_tps": round(r.avg_decode_tps, 2)
if r.avg_decode_tps is not None
else None,
"avg_ms_per_token": round(r.avg_ms_per_token, 2)
if r.avg_ms_per_token is not None
else None,
"std_ms_per_token": round(r.std_ms_per_token, 2)
if r.std_ms_per_token is not None
else None,
"stage_started_at": r.stage_started_at,
"stage_completed_at": r.stage_completed_at,
"stage_duration_s": r.stage_completed_at - r.stage_started_at,
"stage_duration_s": r.stage_completed_at
- r.stage_started_at,
"requests": [
{
"request_id": req.request_id,
@@ -1118,12 +1248,18 @@ async def run_benchmark(
"elapsed_s": round(req.elapsed_s, 3),
"started_at": req.started_at,
"completed_at": req.completed_at,
"time_to_first_token_s": round(req.time_to_first_token_s, 3) if req.time_to_first_token_s is not None else None,
"decode_tps": round(req.decode_tps, 2) if req.decode_tps is not None else None,
"time_to_first_token_s": round(
req.time_to_first_token_s, 3
)
if req.time_to_first_token_s is not None
else None,
"decode_tps": round(req.decode_tps, 2)
if req.decode_tps is not None
else None,
"error": req.error,
}
for req in r.request_results
]
],
}
for r in stage_results
]
@@ -1162,11 +1298,11 @@ async def run_benchmark(
"instance_count": node.instance_count,
}
for node in snapshot.node_tasks
]
],
}
for snapshot in metrics_snapshots
]
}
},
}
# Output JSON summary
@@ -1186,10 +1322,14 @@ async def run_benchmark(
# Cleanup all instances
for instance_id in all_instance_ids:
print(f"[PRIMARY] Cleaning up instance: {instance_id}")
await _http_request_async(f"{api_base}/instance/{instance_id}", method="DELETE")
await _http_request_async(
f"{api_base}/instance/{instance_id}", method="DELETE"
)
print(f"[PRIMARY] Instance {instance_id} deleted successfully")
else:
print("[SECONDARY] Waiting with cluster (primary handles benchmark execution)")
print(
"[SECONDARY] Waiting with cluster (primary handles benchmark execution)"
)
# Secondary nodes wait until all instances of all models are deleted
for model_id in model_counts.keys():
await wait_for_all_instances_deleted(api_base, model_id)
@@ -1205,39 +1345,56 @@ async def run_benchmark(
print("=" * 80)
print(f"ERROR: {e}")
import traceback
traceback.print_exc()
print("=" * 80)
return 1
def main() -> int:
parser = argparse.ArgumentParser(description="Run unified benchmark for EXO (single or multi-stage)")
parser = argparse.ArgumentParser(
description="Run unified benchmark for EXO (single or multi-stage)"
)
parser.add_argument("--api-port", type=int, required=True)
parser.add_argument("--config", type=Path, required=True, help="Path to YAML config file")
parser.add_argument("--expected-nodes", type=int, required=True, help="Total number of nodes expected in the cluster")
parser.add_argument("--is-primary", type=str, choices=["true", "false"], required=True)
parser.add_argument(
"--config", type=Path, required=True, help="Path to YAML config file"
)
parser.add_argument(
"--expected-nodes",
type=int,
required=True,
help="Total number of nodes expected in the cluster",
)
parser.add_argument(
"--is-primary", type=str, choices=["true", "false"], required=True
)
parser.add_argument("--timeout-seconds", type=int, default=1800)
parser.add_argument("--output", type=Path, help="Path to save detailed results JSON")
parser.add_argument(
"--output", type=Path, help="Path to save detailed results JSON"
)
parser.add_argument("--git-commit", type=str, help="Git commit hash for metadata")
parser.add_argument("--hardware-labels", type=str, help="Comma-separated hardware labels")
parser.add_argument(
"--hardware-labels", type=str, help="Comma-separated hardware labels"
)
args = parser.parse_args()
api_base = f"http://localhost:{args.api_port}"
is_primary = args.is_primary.lower() == "true"
hardware_labels = args.hardware_labels.split(",") if args.hardware_labels else None
return asyncio.run(run_benchmark(
api_base,
args.config,
args.expected_nodes,
is_primary,
args.timeout_seconds,
results_output_path=args.output,
git_commit=args.git_commit,
hardware_labels=hardware_labels,
))
return asyncio.run(
run_benchmark(
api_base,
args.config,
args.expected_nodes,
is_primary,
args.timeout_seconds,
results_output_path=args.output,
git_commit=args.git_commit,
hardware_labels=hardware_labels,
)
)
if __name__ == "__main__":
sys.exit(main())
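Putting the argparse definitions above together, each node launches the benchmark roughly as sketched below. The script filename, port, and node count are placeholders (the real filename is not shown in this diff); the config path comes from the module docstring, and only one node passes --is-primary true.

# hypothetical per-node invocation; filename and numeric values are placeholders
import subprocess

subprocess.run(
    [
        "python",
        "run_bench.py",  # placeholder script name
        "--api-port", "52415",
        "--config", ".github/configs/bench_simple.yaml",
        "--expected-nodes", "2",
        "--is-primary", "true",
        "--timeout-seconds", "1800",  # matches the argparse default above
        "--output", "results.json",
    ],
    check=True,
)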

View File

@@ -24,12 +24,12 @@ class Config(TypedDict):
# Read the config file
config_file: str = os.environ['CONFIG_FILE']
with open(config_file, 'r') as f:
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
config: Config = cast(Config, yaml.safe_load(f))
# Extract hardware plan from config
plan: dict[str, int] = config['hardware_plan']
plan: dict[str, int] = config["hardware_plan"]
if not plan:
raise ValueError(f"No hardware_plan found in {config_file}")
@@ -40,22 +40,24 @@ for label, count in plan.items():
entries.append({"label": label, "index": idx})
total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {"include": [
{
"label": e["label"],
"index": e["index"],
"is_primary": (i == 0),
"expected_nodes": total_nodes
}
for i, e in enumerate(entries)
]}
matrix: dict[str, list[MatrixInclude]] = {
"include": [
{
"label": e["label"],
"index": e["index"],
"is_primary": (i == 0),
"expected_nodes": total_nodes,
}
for i, e in enumerate(entries)
]
}
# Extract other config values
timeout_seconds: int = config.get('timeout_seconds', 600)
environment: dict[str, str] = config.get('environment', {})
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})
# Output to GitHub Actions
with open(os.environ['GITHUB_OUTPUT'], 'a') as f:
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write(f"matrix={json.dumps(matrix)}\n")
f.write(f"config_file={config_file}\n")
f.write(f"timeout_seconds={timeout_seconds}\n")
@@ -65,4 +67,3 @@ print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")

View File

@@ -14,21 +14,20 @@ use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
use networking::discovery;
use networking::swarm::create_swarm;
use util::ext::VecExt as _;
mod exception {
use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments};
use pyo3::types::PyTuple;
use pyo3_stub_gen::{derive::*};
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
@@ -71,7 +70,8 @@ mod exception {
pub struct PyAllQueuesFullError {}
impl PyAllQueuesFullError {
const MSG: &'static str = "All libp2p peers are unresponsive, resend the message or reconnect.";
const MSG: &'static str =
"All libp2p peers are unresponsive, resend the message or reconnect.";
/// Creates a new [ `PyErr` ] of this type.
///
@@ -154,10 +154,10 @@ async fn networking_task(
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use networking::swarm::BehaviourEvent::*;
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
@@ -367,7 +367,7 @@ impl PyNetworkingHandle {
connection_update_tx,
gossipsub_message_tx,
)
.await;
.await;
});
Ok(Self::new(
to_task_tx,

View File

@@ -18,17 +18,14 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::{
error::Error,
hash::{Hash},
};
use std::time::Duration;
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

View File

@@ -1,3 +1,4 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
@@ -7,7 +8,11 @@ use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm};
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
@@ -16,16 +21,14 @@ use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
use crate::ext::MultiaddrExt;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
use std::io;
use std::time::Duration;
use libp2p::{identity, mdns, ping};
use libp2p::swarm::NetworkBehaviour;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
@@ -64,7 +67,11 @@ mod managed {
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL))
ping::Behaviour::new(
ping::Config::new()
.with_timeout(PING_TIMEOUT)
.with_interval(PING_INTERVAL),
)
}
}
@@ -129,7 +136,6 @@ impl Behaviour {
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
@@ -202,7 +208,7 @@ impl Behaviour {
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
@@ -261,11 +267,10 @@ impl NetworkBehaviour for Behaviour {
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => self.managed.on_connection_handler_event(
peer_id,
connection_id,
ev,
),
Either::Right(ev) => {
self.managed
.on_connection_handler_event(peer_id, connection_id, ev)
}
}
}
@@ -277,11 +282,11 @@ impl NetworkBehaviour for Behaviour {
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
@@ -293,11 +298,11 @@ impl NetworkBehaviour for Behaviour {
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
@@ -331,7 +336,7 @@ impl NetworkBehaviour for Behaviour {
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
}
},
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
@@ -346,7 +351,6 @@ impl NetworkBehaviour for Behaviour {
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(

View File

@@ -20,13 +20,13 @@ impl handler::ConnectionHandler for ConnectionHandler {
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {

View File

@@ -28,10 +28,10 @@ pub(crate) mod alias {
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use std::net::IpAddr;
use extend::ext;
use libp2p::Multiaddr;
use libp2p::multiaddr::Protocol;
use std::net::IpAddr;
#[ext(pub, name = MultiaddrExt)]
impl Multiaddr {
@@ -42,7 +42,7 @@ pub(crate) mod ext {
match p {
Protocol::Ip4(ip) => IpAddr::V4(ip),
Protocol::Ip6(ip) => IpAddr::V6(ip),
_ => return None
_ => return None,
}
} else {
return None;

View File

@@ -37,7 +37,7 @@ mod transport {
use libp2p::core::transport::Boxed;
use libp2p::pnet::{PnetError, PnetOutput};
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
use std::{sync::LazyLock, env};
use std::{env, sync::LazyLock};
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
/// See [`pnet_upgrade`] for more.
@@ -49,7 +49,8 @@ mod transport {
builder.update(&bytes)
} else {
builder.update(NETWORK_VERSION)
}.finalize()
}
.finalize()
});
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
@@ -103,9 +104,9 @@ mod transport {
mod behaviour {
use crate::{alias, discovery};
use std::time::Duration;
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
use std::time::Duration;
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
@@ -139,6 +140,6 @@ mod behaviour {
.build()
.expect("the configuration should always be valid"),
)
.expect("creating gossipsub behavior should always work")
.expect("creating gossipsub behavior should always work")
}
}

View File

@@ -27,7 +27,7 @@ def stream_chat(host: str, query: str) -> None:
if not line.startswith("data:"):
continue
data = line[len("data:"):].strip()
data = line[len("data:") :].strip()
if data == "[DONE]":
break
@@ -55,7 +55,8 @@ def main() -> None:
)
parser.add_argument("host", help="Hostname (without protocol), e.g. localhost")
parser.add_argument(
"-f", "--file",
"-f",
"--file",
help="Path to a text file whose contents will be used as the query",
)
parser.add_argument(