Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)

Commit: fmt: format all python/rust/nix files
.github/scripts/bench.py (vendored): 333 changed lines
@@ -9,6 +9,7 @@ Requests are fire-and-forget, allowing overlapping execution.
 Simple benchmark (1 iteration): --config .github/configs/bench_simple.yaml
 Complex benchmark (multiple stages): --config .github/configs/bench_config.yaml
 """
 
+# pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
 from __future__ import annotations
 
@@ -34,7 +35,11 @@ def _format_http_error(error: Exception) -> str:
         except Exception:
             body = "<unable to read body>"
 
-        headers_str = "\n".join(f" {k}: {v}" for k, v in error.headers.items()) if error.headers else "<no headers>"
+        headers_str = (
+            "\n".join(f" {k}: {v}" for k, v in error.headers.items())
+            if error.headers
+            else "<no headers>"
+        )
 
         return (
             f"HTTP {error.code} {error.reason}\n"
@@ -48,7 +53,9 @@ def _format_http_error(error: Exception) -> str:
     return str(error)
 
 
-def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]:
+def _http_request(
+    url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None
+) -> dict[str, Any]:
     headers = {"Content-Type": "application/json"}
     payload: bytes | None = None
     if data is not None:
@@ -67,14 +74,21 @@ def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | No
         raise
 
 
-async def _http_request_async(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]:
+async def _http_request_async(
+    url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None
+) -> dict[str, Any]:
     """Async version that runs in executor to not block event loop."""
     loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(None, lambda: _http_request(url, method=method, data=data))
+    return await loop.run_in_executor(
+        None, lambda: _http_request(url, method=method, data=data)
+    )
 
 
-async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300) -> list[tuple[str, float]]:
+async def _http_stream_async(
+    url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300
+) -> list[tuple[str, float]]:
    """Async streaming request. Returns list of (line, timestamp) tuples."""
 
     def _stream() -> list[tuple[str, float]]:
         headers = {"Content-Type": "application/json"}
         payload = json.dumps(data).encode("utf-8")
@@ -92,6 +106,7 @@ async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[st
             error_details = _format_http_error(e)
             print(f"HTTP request failed:\n{error_details}")
             raise
+
     loop = asyncio.get_event_loop()
     return await loop.run_in_executor(None, _stream)
 
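A note on the two async helpers above: they are thin wrappers that push a blocking HTTP call onto the default thread-pool executor so the event loop stays responsive. A self-contained sketch of that offloading pattern, with hypothetical names rather than the repo's helpers:

import asyncio
import time


def blocking_fetch(url: str) -> str:
    # Stand-in for a blocking call such as urllib.request.urlopen().
    time.sleep(0.1)
    return f"response from {url}"


async def fetch_async(url: str) -> str:
    loop = asyncio.get_event_loop()
    # Run the blocking call in the default thread-pool executor so the
    # event loop can keep servicing other tasks meanwhile.
    return await loop.run_in_executor(None, lambda: blocking_fetch(url))


async def main() -> None:
    results = await asyncio.gather(
        fetch_async("http://localhost:8000/a"),
        fetch_async("http://localhost:8000/b"),
    )
    print(results)


asyncio.run(main())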
@@ -196,7 +211,9 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot:
 
         # Extract actual task from wrapper (e.g., {"ChatCompletion": {...}})
         if len(task_wrapper) != 1:
-            print(f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}")
+            print(
+                f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}"
+            )
             tasks_skipped += 1
             continue
 
@@ -223,7 +240,9 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot:
             tasks_matched += 1
 
     if tasks_skipped > 0:
-        print(f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)")
+        print(
+            f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)"
+        )
 
     # Build snapshots for each instance (assign to primary node - first in sorted order)
     for instance_id, counts in instance_task_counts.items():
@@ -310,7 +329,9 @@ def count_instances_by_model(state: Mapping[str, Any], model_id: str) -> int:
     return count
 
 
-def get_all_instance_ids_for_model(state: Mapping[str, Any], model_id: str) -> list[str]:
+def get_all_instance_ids_for_model(
+    state: Mapping[str, Any], model_id: str
+) -> list[str]:
     """Get all instance IDs for a given model_id."""
     instances: Mapping[str, Any] = state.get("instances", {})
     instance_ids = []
@@ -358,7 +379,9 @@ def count_ready_instances_by_model(state: Mapping[str, Any], model_id: str) -> i
     return ready_count
 
 
-def get_runner_ids_for_instance(state: Mapping[str, Any], instance_id: str) -> list[str]:
+def get_runner_ids_for_instance(
+    state: Mapping[str, Any], instance_id: str
+) -> list[str]:
     instances: Mapping[str, Any] = state.get("instances", {})
     instance_wrapped = instances.get(instance_id, {})
 
@@ -387,28 +410,40 @@ def get_runner_status_kind(state: Mapping[str, Any], runner_id: str) -> str | No
     return None
 
 
-async def wait_for_topology_ready(api_base: str, expected_nodes: int, timeout_s: int) -> None:
+async def wait_for_topology_ready(
+    api_base: str, expected_nodes: int, timeout_s: int
+) -> None:
     """Wait for all expected nodes to appear in the topology."""
-    print(f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)...")
+    print(
+        f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)..."
+    )
     start = time.monotonic()
     while True:
         state = fetch_state(api_base)
         node_count = get_topology_node_count(state)
         elapsed = time.monotonic() - start
-        print(f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)")
+        print(
+            f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)"
+        )
 
         if node_count >= expected_nodes:
            print(f"All {expected_nodes} node(s) are in topology!")
            return
 
         if elapsed > timeout_s:
-            raise TimeoutError(f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}")
+            raise TimeoutError(
+                f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}"
+            )
         await asyncio.sleep(2)
 
 
-async def wait_for_instances_ready(api_base: str, model_id: str, expected_count: int, timeout_s: int) -> list[str]:
+async def wait_for_instances_ready(
+    api_base: str, model_id: str, expected_count: int, timeout_s: int
+) -> list[str]:
     """Wait for a specific count of instances for a model to be fully ready."""
-    print(f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)...")
+    print(
+        f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)..."
+    )
     start = time.monotonic()
     while True:
         state = fetch_state(api_base)
@@ -417,11 +452,15 @@ async def wait_for_instances_ready(api_base: str, model_id: str, expected_count:
         ready_count = count_ready_instances_by_model(state, model_id)
         elapsed = time.monotonic() - start
 
-        print(f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)")
+        print(
+            f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)"
+        )
 
         if ready_count >= expected_count:
             instance_ids = get_all_instance_ids_for_model(state, model_id)
-            print(f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}")
+            print(
+                f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}"
+            )
             return instance_ids
 
         if elapsed > timeout_s:
@@ -452,9 +491,9 @@ async def wait_for_tasks_drained(api_base: str, timeout_s: int = 600) -> None:
     Tasks are deleted from state when complete, so we wait until there are no
     pending or running tasks remaining.
     """
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print(f"⏳ WAITING FOR ALL TASKS TO DRAIN")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")
     start = time.monotonic()
 
     while True:
@@ -472,18 +511,26 @@ async def wait_for_tasks_drained(api_base: str, timeout_s: int = 600) -> None:
             print(f"✅ All tasks drained after {elapsed:.1f}s")
             return
 
-        print(f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)")
+        print(
+            f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)"
+        )
 
         # Print per-node breakdown if there are active tasks
         if snapshot.node_tasks:
             for node_snapshot in snapshot.node_tasks:
                 if node_snapshot.total_active_tasks > 0:
                     node_short = node_snapshot.node_id[-4:]
-                    print(f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending")
+                    print(
+                        f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending"
+                    )
 
         if elapsed > timeout_s:
-            print(f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s")
-            print(f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)")
+            print(
+                f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s"
+            )
+            print(
+                f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)"
+            )
             return
 
         await asyncio.sleep(2)
@@ -545,6 +592,7 @@ class StageResult:
 @dataclass(frozen=True)
 class MemorySnapshot:
     """Memory snapshot for a node at a point in time."""
+
     ram_total_bytes: int
     ram_available_bytes: int
     ram_used_bytes: int
@@ -560,6 +608,7 @@ class InstanceTaskSnapshot:
     Note: Tasks are deleted from state when complete, so we only track active tasks.
     total_active_tasks = pending + running.
     """
+
     instance_id: str
     node_id: str
     pending_tasks: int
@@ -574,6 +623,7 @@ class NodeTaskSnapshot:
     Note: Tasks are deleted from state when complete, so we only track active tasks.
     total_active_tasks = pending + running across all instances on this node.
     """
+
     node_id: str
     pending_tasks: int
     running_tasks: int
@@ -584,6 +634,7 @@ class NodeTaskSnapshot:
 @dataclass(frozen=True)
 class MetricsSnapshot:
     """System metrics snapshot at a point in time."""
+
     timestamp: float
     node_memory: dict[str, MemorySnapshot]
     instance_tasks: list[InstanceTaskSnapshot]
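Taken together, the Python hunks above are mechanical restyling, not behavior changes: a blank line inserted after each class docstring, long calls exploded across lines with trailing commas, spaces around binary `*` (as in `'=' * 80`) but none around `**`. These are the conventions of a Black-style autoformatter; the commit message says only "fmt", so attributing it to ruff format specifically is an assumption. A minimal before/after illustration of those rules, with hypothetical names:

from dataclasses import dataclass


@dataclass(frozen=True)
class Snapshot:
    """A point-in-time sample."""

    # Black/ruff format insert a blank line between a class docstring
    # and the first field, as in the dataclass hunks above.
    value: float


variance = 2.25
# No spaces around ** (it binds tightly)...
std = variance**0.5
# ...but spaces around * in ordinary expressions:
print(f"{'=' * 80}")
print(Snapshot(value=std))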
@@ -622,7 +673,7 @@ async def run_single_request(
         for line, timestamp in lines:
             if not line.startswith("data:"):
                 continue
-            payload = line[len("data:"):].strip()
+            payload = line[len("data:") :].strip()
             if payload == "[DONE]":
                 got_done = True
                 break
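The slice change here is another formatter rule: Black-style formatters put a space before the colon when a slice bound is a non-trivial expression such as `len("data:")`. The surrounding loop is parsing server-sent events from an OpenAI-style streaming endpoint. A self-contained sketch of that parsing logic, using made-up sample lines rather than real server output:

# Minimal SSE parsing sketch (hypothetical sample lines, not repo code).
lines = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    "data: [DONE]",
]

for line in lines:
    if not line.startswith("data:"):
        continue
    payload = line[len("data:") :].strip()  # formatter: space before ':' here
    if payload == "[DONE]":
        print("stream finished")
        break
    print("chunk:", payload)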
@@ -655,7 +706,9 @@ async def run_single_request(
 
         # Request is only successful if we got at least one token AND a [DONE] marker
         if tokens == 0:
-            print(f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s")
+            print(
+                f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s"
+            )
             return RequestResult(
                 request_id=request_id,
                 success=False,
@@ -665,11 +718,13 @@ async def run_single_request(
                 completed_at=completed_at,
                 time_to_first_token_s=time_to_first_token,
                 decode_tps=decode_tps,
-                error="No tokens generated"
+                error="No tokens generated",
             )
 
         if not got_done:
-            print(f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s")
+            print(
+                f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s"
+            )
             return RequestResult(
                 request_id=request_id,
                 success=False,
@@ -679,12 +734,16 @@ async def run_single_request(
                 completed_at=completed_at,
                 time_to_first_token_s=time_to_first_token,
                 decode_tps=decode_tps,
-                error="Incomplete response (no [DONE] marker)"
+                error="Incomplete response (no [DONE] marker)",
             )
 
-        ttft_str = f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A"
+        ttft_str = (
+            f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A"
+        )
         tps_str = f"{decode_tps:.1f} t/s" if decode_tps is not None else "N/A"
-        print(f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})")
+        print(
+            f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})"
+        )
         return RequestResult(
             request_id=request_id,
             success=True,
@@ -693,7 +752,7 @@ async def run_single_request(
             started_at=started_at,
             completed_at=completed_at,
             time_to_first_token_s=time_to_first_token,
-            decode_tps=decode_tps
+            decode_tps=decode_tps,
         )
 
     except Exception as e:
@@ -710,7 +769,7 @@ async def run_single_request(
             completed_at=completed_at,
             time_to_first_token_s=None,
             decode_tps=None,
-            error=error_details
+            error=error_details,
         )
 
 
@@ -721,9 +780,9 @@ async def monitor_metrics(
     interval_seconds: float = 5.0,
 ) -> None:
     """Background task that collects metrics snapshots every interval_seconds."""
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print(f"🔍 METRICS MONITORING STARTED (polling every {interval_seconds}s)")
-    print(f"{'='*80}\n")
+    print(f"{'=' * 80}\n")
 
     snapshot_count = 0
     while not stop_event.is_set():
@@ -743,17 +802,22 @@ async def monitor_metrics(
             total_active = sum(node.total_active_tasks for node in snapshot.node_tasks)
 
             # Print detailed breakdown
-            print(f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)")
+            print(
+                f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)"
+            )
 
             # Print per-node breakdown (only if there are nodes)
             if snapshot.node_tasks:
                 for node_snapshot in snapshot.node_tasks:
                     node_short = node_snapshot.node_id[-4:]
-                    print(f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances")
+                    print(
+                        f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances"
+                    )
 
         except Exception as e:
             print(f"[METRICS] Error collecting snapshot: {e}")
             import traceback
+
             traceback.print_exc()
 
         # Wait for interval or until stopped
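monitor_metrics polls in a loop and exits once an asyncio.Event is set; the "wait for interval or until stopped" comment suggests it sleeps on the event with a timeout so a stop request interrupts the wait immediately. A standalone sketch of that pattern, with hypothetical names and assuming asyncio.wait_for over stop_event.wait() (the repo's exact mechanism is not shown in this hunk):

import asyncio


async def monitor(stop_event: asyncio.Event, interval_seconds: float) -> None:
    count = 0
    while not stop_event.is_set():
        count += 1
        print(f"[METRICS #{count}] collecting snapshot...")
        # Sleep for the interval, but wake immediately if the event is set.
        try:
            await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
        except asyncio.TimeoutError:
            pass  # interval elapsed; poll again


async def main() -> None:
    stop = asyncio.Event()
    task = asyncio.create_task(monitor(stop, interval_seconds=0.2))
    await asyncio.sleep(1)
    stop.set()  # request shutdown; monitor exits promptly
    await task


asyncio.run(main())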
@@ -788,7 +852,9 @@ async def run_stage(
         # Sequential execution: wait for each request to complete before starting next
         print("\nRunning requests sequentially (no overlap)...")
         for i in range(stage.iterations):
-            result = await run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1)
+            result = await run_single_request(
+                api_base, model_id, prompt, stage.generation_length, i + 1
+            )
             results.append(result)
 
             # Wait before starting next request (except after last one)
@@ -802,7 +868,9 @@ async def run_stage(
         # Fire off requests with delays between them
         for i in range(stage.iterations):
             task = asyncio.create_task(
-                run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1)
+                run_single_request(
+                    api_base, model_id, prompt, stage.generation_length, i + 1
+                )
             )
             tasks.append(task)
 
@@ -833,29 +901,47 @@ async def run_stage(
     successful_results = [r for r in results if r.success]
 
     # Skip first iteration if there are more than 1 iterations (warmup)
-    results_for_stats = successful_results[1:] if len(successful_results) > 1 else successful_results
+    results_for_stats = (
+        successful_results[1:] if len(successful_results) > 1 else successful_results
+    )
 
     # TTFT statistics
-    ttft_values = [r.time_to_first_token_s for r in results_for_stats if r.time_to_first_token_s is not None]
+    ttft_values = [
+        r.time_to_first_token_s
+        for r in results_for_stats
+        if r.time_to_first_token_s is not None
+    ]
     avg_ttft = sum(ttft_values) / len(ttft_values) if ttft_values else None
 
     if avg_ttft is not None and len(ttft_values) > 1:
         variance_ttft = sum((x - avg_ttft) ** 2 for x in ttft_values) / len(ttft_values)
-        std_ttft = variance_ttft ** 0.5
+        std_ttft = variance_ttft**0.5
     else:
         std_ttft = None
 
     # Decode TPS and ms per token statistics
-    decode_tps_values = [r.decode_tps for r in results_for_stats if r.decode_tps is not None]
-    avg_decode_tps = sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None
+    decode_tps_values = [
+        r.decode_tps for r in results_for_stats if r.decode_tps is not None
+    ]
+    avg_decode_tps = (
+        sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None
+    )
 
     # Convert to ms per token
-    ms_per_token_values = [1000.0 / tps for tps in decode_tps_values] if decode_tps_values else []
-    avg_ms_per_token = sum(ms_per_token_values) / len(ms_per_token_values) if ms_per_token_values else None
+    ms_per_token_values = (
+        [1000.0 / tps for tps in decode_tps_values] if decode_tps_values else []
+    )
+    avg_ms_per_token = (
+        sum(ms_per_token_values) / len(ms_per_token_values)
+        if ms_per_token_values
+        else None
+    )
 
     if avg_ms_per_token is not None and len(ms_per_token_values) > 1:
-        variance_ms_per_token = sum((x - avg_ms_per_token) ** 2 for x in ms_per_token_values) / len(ms_per_token_values)
-        std_ms_per_token = variance_ms_per_token ** 0.5
+        variance_ms_per_token = sum(
+            (x - avg_ms_per_token) ** 2 for x in ms_per_token_values
+        ) / len(ms_per_token_values)
+        std_ms_per_token = variance_ms_per_token**0.5
     else:
         std_ms_per_token = None
 
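The statistics in this hunk are a hand-rolled population mean and standard deviation: variance is the mean of squared deviations, and `**0.5` takes the square root (the formatter only removed the spaces around `**`). A standalone sketch of the same computation, with made-up sample values rather than benchmark output:

# Population mean / std-dev, same formula as the benchmark code.
ttft_values = [0.41, 0.38, 0.44, 0.40]  # hypothetical TTFT samples (seconds)

avg_ttft = sum(ttft_values) / len(ttft_values)
variance_ttft = sum((x - avg_ttft) ** 2 for x in ttft_values) / len(ttft_values)
std_ttft = variance_ttft**0.5

print(f"Avg TTFT: {avg_ttft:.3f}s ± {std_ttft:.3f}s")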
@@ -920,8 +1006,12 @@ async def run_benchmark(
     print(f"Configuration File: {config_path}")
     print(f"Model IDs: {model_ids}")
     print(f"Instance Count: {len(model_ids)}")
-    print(f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}")
-    print(f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}")
+    print(
+        f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}"
+    )
+    print(
+        f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}"
+    )
     print(f"No Overlap: {no_overlap}")
     print(f"Stages: {len(stages)}")
     print(f"Expected Nodes: {expected_nodes}")
@@ -930,15 +1020,20 @@ async def run_benchmark(
 
     try:
         # Wait for all nodes to join the topology first
-        await wait_for_topology_ready(api_base, expected_nodes, timeout_s=timeout_seconds)
+        await wait_for_topology_ready(
+            api_base, expected_nodes, timeout_s=timeout_seconds
+        )
 
         # Add 30 second delay to allow topology to stabilize before creating instances
-        print(f"\nWaiting 30 seconds for topology to stabilize before creating instances...")
+        print(
+            f"\nWaiting 30 seconds for topology to stabilize before creating instances..."
+        )
         await asyncio.sleep(30)
         print("Proceeding with instance creation\n")
 
         # Count how many instances we need for each unique model_id
         from collections import Counter
+
         model_counts = Counter(model_ids)
 
         print(f"\nTarget instance counts by model:")
@@ -958,8 +1053,12 @@ async def run_benchmark(
                 target_count = current_ready + 1
 
                 print("=" * 80)
-                print(f"[PRIMARY] Creating instance {idx+1}/{len(model_ids)} for model: {model_id}")
-                print(f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}")
+                print(
+                    f"[PRIMARY] Creating instance {idx + 1}/{len(model_ids)} for model: {model_id}"
+                )
+                print(
+                    f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}"
+                )
 
                 # Build instance creation request data
                 instance_data: dict[str, Any] = {"model_id": model_id}
@@ -969,21 +1068,23 @@ async def run_benchmark(
                     instance_data["instance_meta"] = instance_meta
 
                 response = await _http_request_async(
-                    f"{api_base}/instance",
-                    method="POST",
-                    data=instance_data
+                    f"{api_base}/instance", method="POST", data=instance_data
                 )
                 print(f"[PRIMARY] Instance creation response: {response}")
 
                 # Wait for one more instance of this model to be ready
-                await wait_for_instances_ready(api_base, model_id, target_count, timeout_s=timeout_seconds)
-                print(f"[PRIMARY] Instance {idx+1}/{len(model_ids)} is ready")
+                await wait_for_instances_ready(
+                    api_base, model_id, target_count, timeout_s=timeout_seconds
+                )
+                print(f"[PRIMARY] Instance {idx + 1}/{len(model_ids)} is ready")
                 print("=" * 80)
         else:
             # Secondary: wait for expected counts of each model to be ready
             print("[SECONDARY] Waiting for all instances to be created and ready...")
             for model_id, expected_count in model_counts.items():
-                await wait_for_instances_ready(api_base, model_id, expected_count, timeout_s=timeout_seconds)
+                await wait_for_instances_ready(
+                    api_base, model_id, expected_count, timeout_s=timeout_seconds
+                )
 
         # Collect all instance IDs for all models
         state = fetch_state(api_base)
@@ -997,7 +1098,9 @@ async def run_benchmark(
             runner_ids = get_runner_ids_for_instance(state, instance_id)
             total_runners += len(runner_ids)
 
-        print(f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!")
+        print(
+            f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!"
+        )
         print(f"Instance IDs: {all_instance_ids}")
 
         if is_primary:
@@ -1013,12 +1116,16 @@ async def run_benchmark(
             metrics_snapshots: list[MetricsSnapshot] = []
             stop_monitoring = asyncio.Event()
             monitoring_task = asyncio.create_task(
-                monitor_metrics(api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5)
+                monitor_metrics(
+                    api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5
+                )
             )
 
             stage_results: list[StageResult] = []
             for stage in stages:
-                result = await run_stage(api_base, benchmark_model_id, stage, no_overlap=no_overlap)
+                result = await run_stage(
+                    api_base, benchmark_model_id, stage, no_overlap=no_overlap
+                )
                 stage_results.append(result)
 
             # Stop metrics monitoring
@@ -1046,12 +1153,18 @@ async def run_benchmark(
                 print(f" Avg Time/Request: {result.avg_time_per_request:.2f}s")
                 if result.avg_time_to_first_token is not None:
                     if result.std_time_to_first_token is not None:
-                        print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s")
+                        print(
+                            f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s"
+                        )
                     else:
-                        print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s")
+                        print(
+                            f" Avg TTFT: {result.avg_time_to_first_token:.3f}s"
+                        )
                 if result.avg_ms_per_token is not None:
                     if result.std_ms_per_token is not None:
-                        print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms")
+                        print(
+                            f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms"
+                        )
                     else:
                         print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms")
                 if result.avg_decode_tps is not None:
@@ -1100,16 +1213,33 @@ async def run_benchmark(
                     "failed_requests": r.failed_requests,
                     "success_rate": round(r.success_rate, 4),
                     "total_tokens": r.total_tokens,
-                    "avg_tokens_per_request": round(r.avg_tokens_per_request, 2),
+                    "avg_tokens_per_request": round(
+                        r.avg_tokens_per_request, 2
+                    ),
                     "avg_time_per_request": round(r.avg_time_per_request, 3),
-                    "avg_time_to_first_token": round(r.avg_time_to_first_token, 3) if r.avg_time_to_first_token is not None else None,
-                    "std_time_to_first_token": round(r.std_time_to_first_token, 3) if r.std_time_to_first_token is not None else None,
-                    "avg_decode_tps": round(r.avg_decode_tps, 2) if r.avg_decode_tps is not None else None,
-                    "avg_ms_per_token": round(r.avg_ms_per_token, 2) if r.avg_ms_per_token is not None else None,
-                    "std_ms_per_token": round(r.std_ms_per_token, 2) if r.std_ms_per_token is not None else None,
+                    "avg_time_to_first_token": round(
+                        r.avg_time_to_first_token, 3
+                    )
+                    if r.avg_time_to_first_token is not None
+                    else None,
+                    "std_time_to_first_token": round(
+                        r.std_time_to_first_token, 3
+                    )
+                    if r.std_time_to_first_token is not None
+                    else None,
+                    "avg_decode_tps": round(r.avg_decode_tps, 2)
+                    if r.avg_decode_tps is not None
+                    else None,
+                    "avg_ms_per_token": round(r.avg_ms_per_token, 2)
+                    if r.avg_ms_per_token is not None
+                    else None,
+                    "std_ms_per_token": round(r.std_ms_per_token, 2)
+                    if r.std_ms_per_token is not None
+                    else None,
                     "stage_started_at": r.stage_started_at,
                     "stage_completed_at": r.stage_completed_at,
-                    "stage_duration_s": r.stage_completed_at - r.stage_started_at,
+                    "stage_duration_s": r.stage_completed_at
+                    - r.stage_started_at,
                     "requests": [
                         {
                             "request_id": req.request_id,
@@ -1118,12 +1248,18 @@ async def run_benchmark(
                             "elapsed_s": round(req.elapsed_s, 3),
                             "started_at": req.started_at,
                             "completed_at": req.completed_at,
-                            "time_to_first_token_s": round(req.time_to_first_token_s, 3) if req.time_to_first_token_s is not None else None,
-                            "decode_tps": round(req.decode_tps, 2) if req.decode_tps is not None else None,
+                            "time_to_first_token_s": round(
+                                req.time_to_first_token_s, 3
+                            )
+                            if req.time_to_first_token_s is not None
+                            else None,
+                            "decode_tps": round(req.decode_tps, 2)
+                            if req.decode_tps is not None
+                            else None,
                             "error": req.error,
                         }
                         for req in r.request_results
-                    ]
+                    ],
                 }
                 for r in stage_results
             ]
@@ -1162,11 +1298,11 @@ async def run_benchmark(
                         "instance_count": node.instance_count,
                     }
                     for node in snapshot.node_tasks
-                ]
+                ],
             }
             for snapshot in metrics_snapshots
         ]
-    }
+    },
 }
 
 # Output JSON summary
@@ -1186,10 +1322,14 @@ async def run_benchmark(
             # Cleanup all instances
             for instance_id in all_instance_ids:
                 print(f"[PRIMARY] Cleaning up instance: {instance_id}")
-                await _http_request_async(f"{api_base}/instance/{instance_id}", method="DELETE")
+                await _http_request_async(
+                    f"{api_base}/instance/{instance_id}", method="DELETE"
+                )
                 print(f"[PRIMARY] Instance {instance_id} deleted successfully")
         else:
-            print("[SECONDARY] Waiting with cluster (primary handles benchmark execution)")
+            print(
+                "[SECONDARY] Waiting with cluster (primary handles benchmark execution)"
+            )
             # Secondary nodes wait until all instances of all models are deleted
             for model_id in model_counts.keys():
                 await wait_for_all_instances_deleted(api_base, model_id)
@@ -1205,28 +1345,45 @@ async def run_benchmark(
         print("=" * 80)
         print(f"ERROR: {e}")
         import traceback
+
         traceback.print_exc()
         print("=" * 80)
         return 1
 
 
 def main() -> int:
-    parser = argparse.ArgumentParser(description="Run unified benchmark for EXO (single or multi-stage)")
+    parser = argparse.ArgumentParser(
+        description="Run unified benchmark for EXO (single or multi-stage)"
+    )
     parser.add_argument("--api-port", type=int, required=True)
-    parser.add_argument("--config", type=Path, required=True, help="Path to YAML config file")
-    parser.add_argument("--expected-nodes", type=int, required=True, help="Total number of nodes expected in the cluster")
-    parser.add_argument("--is-primary", type=str, choices=["true", "false"], required=True)
+    parser.add_argument(
+        "--config", type=Path, required=True, help="Path to YAML config file"
+    )
+    parser.add_argument(
+        "--expected-nodes",
+        type=int,
+        required=True,
+        help="Total number of nodes expected in the cluster",
+    )
+    parser.add_argument(
+        "--is-primary", type=str, choices=["true", "false"], required=True
+    )
     parser.add_argument("--timeout-seconds", type=int, default=1800)
-    parser.add_argument("--output", type=Path, help="Path to save detailed results JSON")
+    parser.add_argument(
+        "--output", type=Path, help="Path to save detailed results JSON"
+    )
    parser.add_argument("--git-commit", type=str, help="Git commit hash for metadata")
-    parser.add_argument("--hardware-labels", type=str, help="Comma-separated hardware labels")
+    parser.add_argument(
+        "--hardware-labels", type=str, help="Comma-separated hardware labels"
+    )
     args = parser.parse_args()
 
     api_base = f"http://localhost:{args.api_port}"
     is_primary = args.is_primary.lower() == "true"
     hardware_labels = args.hardware_labels.split(",") if args.hardware_labels else None
 
-    return asyncio.run(run_benchmark(
+    return asyncio.run(
+        run_benchmark(
             api_base,
             args.config,
             args.expected_nodes,
@@ -1235,9 +1392,9 @@ def main() -> int:
             results_output_path=args.output,
             git_commit=args.git_commit,
             hardware_labels=hardware_labels,
-    ))
+        )
+    )
 
 
 if __name__ == "__main__":
     sys.exit(main())
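The final bench.py hunks apply the same wrap-and-trailing-comma treatment to the argparse setup and the `asyncio.run(...)` entrypoint. A minimal standalone sketch of that entrypoint pattern, with hypothetical flags rather than the full benchmark CLI:

import argparse
import asyncio
import sys


async def run_benchmark(api_base: str, expected_nodes: int) -> int:
    # Stand-in for the real async benchmark driver.
    print(f"would benchmark {expected_nodes} node(s) at {api_base}")
    return 0


def main() -> int:
    parser = argparse.ArgumentParser(description="Entrypoint sketch")
    parser.add_argument("--api-port", type=int, required=True)
    parser.add_argument("--expected-nodes", type=int, required=True)
    args = parser.parse_args()

    # asyncio.run() drives the async entrypoint to completion and returns
    # its result, which becomes the process exit code via sys.exit().
    return asyncio.run(
        run_benchmark(
            f"http://localhost:{args.api_port}",
            args.expected_nodes,
        )
    )


if __name__ == "__main__":
    sys.exit(main())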
.github/scripts/build_matrix.py (vendored): 21 changed lines
@@ -24,12 +24,12 @@ class Config(TypedDict):
 
 
 # Read the config file
-config_file: str = os.environ['CONFIG_FILE']
-with open(config_file, 'r') as f:
+config_file: str = os.environ["CONFIG_FILE"]
+with open(config_file, "r") as f:
     config: Config = cast(Config, yaml.safe_load(f))
 
 # Extract hardware plan from config
-plan: dict[str, int] = config['hardware_plan']
+plan: dict[str, int] = config["hardware_plan"]
 if not plan:
     raise ValueError(f"No hardware_plan found in {config_file}")
 
@@ -40,22 +40,24 @@ for label, count in plan.items():
         entries.append({"label": label, "index": idx})
 
 total_nodes: int = len(entries)
-matrix: dict[str, list[MatrixInclude]] = {"include": [
+matrix: dict[str, list[MatrixInclude]] = {
+    "include": [
         {
             "label": e["label"],
             "index": e["index"],
             "is_primary": (i == 0),
-            "expected_nodes": total_nodes
+            "expected_nodes": total_nodes,
         }
         for i, e in enumerate(entries)
-]}
+    ]
+}
 
 # Extract other config values
-timeout_seconds: int = config.get('timeout_seconds', 600)
-environment: dict[str, str] = config.get('environment', {})
+timeout_seconds: int = config.get("timeout_seconds", 600)
+environment: dict[str, str] = config.get("environment", {})
 
 # Output to GitHub Actions
-with open(os.environ['GITHUB_OUTPUT'], 'a') as f:
+with open(os.environ["GITHUB_OUTPUT"], "a") as f:
     f.write(f"matrix={json.dumps(matrix)}\n")
     f.write(f"config_file={config_file}\n")
     f.write(f"timeout_seconds={timeout_seconds}\n")
@@ -65,4 +67,3 @@ print(f"Matrix: {json.dumps(matrix)}")
 print(f"Config file: {config_file}")
 print(f"Timeout: {timeout_seconds}")
 print(f"Environment: {json.dumps(environment)}")
-
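build_matrix.py expands the `hardware_plan` mapping from the YAML config into a GitHub Actions job matrix and writes it as a key=value line to the file named by `GITHUB_OUTPUT`. A self-contained sketch of the same expansion, with a hypothetical plan and writing to a local file instead of the Actions-provided path:

import json

# Hypothetical hardware plan: label -> node count (mirrors the YAML shape).
plan = {"m4-pro": 2, "m2-ultra": 1}

entries = [
    {"label": label, "index": idx}
    for label, count in plan.items()
    for idx in range(count)
]
matrix = {
    "include": [
        {
            "label": e["label"],
            "index": e["index"],
            "is_primary": (i == 0),  # first entry drives the benchmark
            "expected_nodes": len(entries),
        }
        for i, e in enumerate(entries)
    ]
}

# GitHub Actions reads key=value lines from the $GITHUB_OUTPUT file;
# here we write to a local file just to show the format.
with open("github_output.txt", "a") as f:
    f.write(f"matrix={json.dumps(matrix)}\n")

print(json.dumps(matrix, indent=2))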
@@ -14,21 +14,20 @@ use libp2p::futures::StreamExt as _;
 use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
 use libp2p::swarm::SwarmEvent;
 use libp2p::{gossipsub, mdns};
-use networking::discovery;
-use networking::swarm::create_swarm;
 use pyo3::prelude::{PyModule, PyModuleMethods as _};
 use pyo3::types::PyBytes;
 use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
 use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
 use std::net::IpAddr;
 use tokio::sync::{Mutex, mpsc, oneshot};
+use networking::discovery;
+use networking::swarm::create_swarm;
 use util::ext::VecExt as _;
 
 mod exception {
-    use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments};
-    use pyo3::types::PyTuple;
-    use pyo3_stub_gen::{derive::*};
+    use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
+    use pyo3_stub_gen::derive::*;
 
 
     #[gen_stub_pyclass]
     #[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
@@ -71,7 +70,8 @@ mod exception {
     pub struct PyAllQueuesFullError {}
 
     impl PyAllQueuesFullError {
-        const MSG: &'static str = "All libp2p peers are unresponsive, resend the message or reconnect.";
+        const MSG: &'static str =
+            "All libp2p peers are unresponsive, resend the message or reconnect.";
 
         /// Creates a new [`PyErr`] of this type.
         ///
@@ -154,10 +154,10 @@ async fn networking_task(
     connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
     gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
 ) {
-    use networking::swarm::BehaviourEvent::*;
     use SwarmEvent::*;
     use ToTask::*;
     use mdns::Event::*;
+    use networking::swarm::BehaviourEvent::*;
 
     log::info!("RUST: networking task started");
 
@@ -18,17 +18,14 @@
 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 // DEALINGS IN THE SOFTWARE.
 
-use std::{
-    error::Error,
-    hash::{Hash},
-};
-use std::time::Duration;
 use futures::stream::StreamExt;
 use libp2p::{
     gossipsub, mdns, noise,
     swarm::{NetworkBehaviour, SwarmEvent},
     tcp, yamux,
 };
+use std::time::Duration;
+use std::{error::Error, hash::Hash};
 use tokio::{io, io::AsyncBufReadExt, select};
 use tracing_subscriber::EnvFilter;
 
@@ -1,3 +1,4 @@
+use crate::ext::MultiaddrExt;
 use crate::keep_alive;
 use delegate::delegate;
 use either::Either;
@@ -7,7 +8,11 @@ use libp2p::core::transport::PortUse;
 use libp2p::core::{ConnectedPoint, Endpoint};
 use libp2p::swarm::behaviour::ConnectionEstablished;
 use libp2p::swarm::dial_opts::DialOpts;
-use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm};
+use libp2p::swarm::{
+    CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
+    ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
+    THandlerOutEvent, ToSwarm, dummy,
+};
 use libp2p::{Multiaddr, PeerId, identity, mdns};
 use std::collections::{BTreeSet, HashMap};
 use std::convert::Infallible;
@@ -16,16 +21,14 @@ use std::net::IpAddr;
 use std::task::{Context, Poll};
 use std::time::Duration;
 use util::wakerdeque::WakerDeque;
-use crate::ext::MultiaddrExt;
-
 
 const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
 
 mod managed {
+    use libp2p::swarm::NetworkBehaviour;
+    use libp2p::{identity, mdns, ping};
     use std::io;
     use std::time::Duration;
-    use libp2p::{identity, mdns, ping};
-    use libp2p::swarm::NetworkBehaviour;
 
     const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
     const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
@@ -64,7 +67,11 @@ mod managed {
     }
 
     fn ping_behaviour() -> ping::Behaviour {
-        ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL))
+        ping::Behaviour::new(
+            ping::Config::new()
+                .with_timeout(PING_TIMEOUT)
+                .with_interval(PING_INTERVAL),
+        )
     }
 }
 
@@ -129,7 +136,6 @@ impl Behaviour {
         })
     }
 
-
     fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
         for (p, ma) in peers {
             self.dial(p, ma.clone()); // always connect
@@ -261,11 +267,10 @@ impl NetworkBehaviour for Behaviour {
     ) {
         match event {
             Either::Left(ev) => libp2p::core::util::unreachable(ev),
-            Either::Right(ev) => self.managed.on_connection_handler_event(
-                peer_id,
-                connection_id,
-                ev,
-            ),
+            Either::Right(ev) => {
+                self.managed
+                    .on_connection_handler_event(peer_id, connection_id, ev)
+            }
         }
     }
 
@@ -331,7 +336,7 @@ impl NetworkBehaviour for Behaviour {
                 mdns::Event::Expired(peers) => {
                     self.handle_mdns_expired(peers);
                 }
-            }
+            },
 
             // handle ping events => if error then disconnect
             managed::BehaviourEvent::Ping(e) => {
@@ -346,7 +351,6 @@ impl NetworkBehaviour for Behaviour {
                 cx.waker().wake_by_ref();
             }
 
-
             // forward any other mDNS event to the swarm or its connection handler(s)
             Poll::Ready(e) => {
                 return Poll::Ready(
@@ -28,10 +28,10 @@ pub(crate) mod alias {
 
 /// Namespace for crate-wide extension traits/methods
 pub(crate) mod ext {
-    use std::net::IpAddr;
     use extend::ext;
     use libp2p::Multiaddr;
     use libp2p::multiaddr::Protocol;
+    use std::net::IpAddr;
 
     #[ext(pub, name = MultiaddrExt)]
     impl Multiaddr {
@@ -42,7 +42,7 @@ pub(crate) mod ext {
             match p {
                 Protocol::Ip4(ip) => IpAddr::V4(ip),
                 Protocol::Ip6(ip) => IpAddr::V6(ip),
-                _ => return None
+                _ => return None,
             }
         } else {
             return None;
@@ -37,7 +37,7 @@ mod transport {
     use libp2p::core::transport::Boxed;
     use libp2p::pnet::{PnetError, PnetOutput};
     use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
-    use std::{sync::LazyLock, env};
+    use std::{env, sync::LazyLock};
 
     /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
     /// See [`pnet_upgrade`] for more.
@@ -49,7 +49,8 @@ mod transport {
             builder.update(&bytes)
         } else {
             builder.update(NETWORK_VERSION)
-        }.finalize()
+        }
+        .finalize()
     });
 
     /// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
@@ -103,9 +104,9 @@ mod transport {
 
 mod behaviour {
     use crate::{alias, discovery};
-    use std::time::Duration;
     use libp2p::swarm::NetworkBehaviour;
     use libp2p::{gossipsub, identity};
+    use std::time::Duration;
 
     /// Behavior of the Swarm which composes all desired behaviors:
     /// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
@@ -27,7 +27,7 @@ def stream_chat(host: str, query: str) -> None:
         if not line.startswith("data:"):
             continue
 
-        data = line[len("data:"):].strip()
+        data = line[len("data:") :].strip()
         if data == "[DONE]":
             break
 
@@ -55,7 +55,8 @@ def main() -> None:
     )
     parser.add_argument("host", help="Hostname (without protocol), e.g. localhost")
     parser.add_argument(
-        "-f", "--file",
+        "-f",
+        "--file",
         help="Path to a text file whose contents will be used as the query",
     )
     parser.add_argument(