From 562998380926ee73e141f185085e6e05cfa510f9 Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Fri, 5 Dec 2025 16:58:55 +0000 Subject: [PATCH] fmt: format all python/rust/nix files --- .github/scripts/bench.py | 613 ++++++++++++-------- .github/scripts/build_matrix.py | 33 +- rust/exo_pyo3_bindings/src/networking.rs | 16 +- rust/networking/examples/chatroom_manual.rs | 9 +- rust/networking/src/discovery.rs | 54 +- rust/networking/src/keep_alive.rs | 8 +- rust/networking/src/lib.rs | 6 +- rust/networking/src/swarm.rs | 13 +- tmp/run_llm.py | 7 +- 9 files changed, 460 insertions(+), 299 deletions(-) diff --git a/.github/scripts/bench.py b/.github/scripts/bench.py index 06b81542..6b4b3ab1 100644 --- a/.github/scripts/bench.py +++ b/.github/scripts/bench.py @@ -9,6 +9,7 @@ Requests are fire-and-forget, allowing overlapping execution. Simple benchmark (1 iteration): --config .github/configs/bench_simple.yaml Complex benchmark (multiple stages): --config .github/configs/bench_config.yaml """ + # pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false from __future__ import annotations @@ -33,9 +34,13 @@ def _format_http_error(error: Exception) -> str: body = error.read().decode("utf-8", errors="replace") except Exception: body = "" - - headers_str = "\n".join(f" {k}: {v}" for k, v in error.headers.items()) if error.headers else "" - + + headers_str = ( + "\n".join(f" {k}: {v}" for k, v in error.headers.items()) + if error.headers + else "" + ) + return ( f"HTTP {error.code} {error.reason}\n" f"URL: {error.url}\n" @@ -48,7 +53,9 @@ def _format_http_error(error: Exception) -> str: return str(error) -def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]: +def _http_request( + url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None +) -> dict[str, Any]: headers = {"Content-Type": "application/json"} payload: bytes | None = None if data is not None: @@ -67,14 +74,21 @@ def _http_request(url: str, *, method: str = "GET", data: Mapping[str, Any] | No raise -async def _http_request_async(url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None) -> dict[str, Any]: +async def _http_request_async( + url: str, *, method: str = "GET", data: Mapping[str, Any] | None = None +) -> dict[str, Any]: """Async version that runs in executor to not block event loop.""" loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, lambda: _http_request(url, method=method, data=data)) + return await loop.run_in_executor( + None, lambda: _http_request(url, method=method, data=data) + ) -async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300) -> list[tuple[str, float]]: +async def _http_stream_async( + url: str, *, method: str = "POST", data: Mapping[str, Any], timeout: int = 300 +) -> list[tuple[str, float]]: """Async streaming request. Returns list of (line, timestamp) tuples.""" + def _stream() -> list[tuple[str, float]]: headers = {"Content-Type": "application/json"} payload = json.dumps(data).encode("utf-8") @@ -92,6 +106,7 @@ async def _http_stream_async(url: str, *, method: str = "POST", data: Mapping[st error_details = _format_http_error(e) print(f"HTTP request failed:\n{error_details}") raise + loop = asyncio.get_event_loop() return await loop.run_in_executor(None, _stream) @@ -102,48 +117,48 @@ def fetch_state(api_base: str) -> dict[str, Any]: def unwrap_tagged_union(obj: Any) -> tuple[str | None, Any]: """Extract tag and payload from tagged union format {Tag: {fields...}}. - + Returns (tag_name, payload) if the object is a tagged union, otherwise (None, obj). """ if not isinstance(obj, dict): return None, obj - + keys = list(obj.keys()) if len(keys) == 1 and isinstance(keys[0], str): tag = keys[0] payload = obj[tag] return tag, payload - + return None, obj def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: """Collect current metrics snapshot from state.""" timestamp = time.time() - + # Collect memory for each node node_memory: dict[str, MemorySnapshot] = {} node_profiles: Mapping[str, Any] = state.get("nodeProfiles", {}) - + for node_id, profile in node_profiles.items(): if not isinstance(profile, dict): continue - + memory = profile.get("memory", {}) if not isinstance(memory, dict): continue - + # Parse memory values - they're objects with 'inBytes' field def get_bytes(mem_obj: Any) -> int: if isinstance(mem_obj, dict): return int(mem_obj.get("inBytes", 0)) return 0 - + ram_total = get_bytes(memory.get("ramTotal")) ram_available = get_bytes(memory.get("ramAvailable")) swap_total = get_bytes(memory.get("swapTotal")) swap_available = get_bytes(memory.get("swapAvailable")) - + node_memory[node_id] = MemorySnapshot( ram_total_bytes=ram_total, ram_available_bytes=ram_available, @@ -152,13 +167,13 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: swap_available_bytes=swap_available, swap_used_bytes=max(swap_total - swap_available, 0), ) - + # Collect task counts per instance and per node instance_tasks: list[InstanceTaskSnapshot] = [] instances: Mapping[str, Any] = state.get("instances", {}) tasks: Mapping[str, Any] = state.get("tasks", {}) print(f"[DEBUG] Num tasks: {len(tasks)}. Num instances: {len(instances)}.") - + # Map instance_id -> node_ids (instances can span multiple nodes) instance_to_nodes: dict[str, set[str]] = {} for instance_id, instance_wrapped in instances.items(): @@ -166,16 +181,16 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: _instance_tag, instance_data = unwrap_tagged_union(instance_wrapped) if not isinstance(instance_data, dict): continue - + shard_assignments = instance_data.get("shardAssignments", {}) if not isinstance(shard_assignments, dict): continue - + # Get all nodes that this instance uses node_to_runner = shard_assignments.get("nodeToRunner", {}) if isinstance(node_to_runner, dict): instance_to_nodes[instance_id] = set(node_to_runner.keys()) - + # Count tasks per instance (only Pending and Running exist in state; completed tasks are deleted) instance_task_counts: dict[str, dict[str, int]] = {} for instance_id in instances.keys(): @@ -183,57 +198,61 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: "Pending": 0, "Running": 0, } - + # Iterate through tasks and count by instance and status tasks_matched = 0 tasks_skipped = 0 - + for _task_id, task_wrapper in tasks.items(): if not isinstance(task_wrapper, dict): print(f"[DEBUG] Task wrapper is not a dict: {task_wrapper}") tasks_skipped += 1 continue - + # Extract actual task from wrapper (e.g., {"ChatCompletion": {...}}) if len(task_wrapper) != 1: - print(f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}") + print( + f"[DEBUG] Task wrapper has unexpected number of keys: {len(task_wrapper)}" + ) tasks_skipped += 1 continue - + _task_type, task_data = next(iter(task_wrapper.items())) - + if not isinstance(task_data, dict): print(f"[DEBUG] Task data is not a dict: {task_data}") tasks_skipped += 1 continue - + instance_id = task_data.get("instanceId") task_status = task_data.get("taskStatus") - + if not instance_id or instance_id not in instance_task_counts: tasks_skipped += 1 continue - + if task_status not in ["Pending", "Running"]: tasks_skipped += 1 continue - + # Count this task instance_task_counts[instance_id][task_status] += 1 tasks_matched += 1 - + if tasks_skipped > 0: - print(f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)") - + print( + f"[DEBUG] Task matching: {tasks_matched} matched, {tasks_skipped} skipped (from {len(tasks)} total)" + ) + # Build snapshots for each instance (assign to primary node - first in sorted order) for instance_id, counts in instance_task_counts.items(): pending = counts["Pending"] running = counts["Running"] total_active = pending + running - + node_ids = instance_to_nodes.get(instance_id, set()) primary_node = sorted(node_ids)[0] if node_ids else "unknown" - + instance_tasks.append( InstanceTaskSnapshot( instance_id=instance_id, @@ -243,32 +262,32 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: total_active_tasks=total_active, ) ) - + # Aggregate tasks per node node_task_counts: dict[str, dict[str, int]] = {} node_instance_counts: dict[str, int] = {} - + for instance_snapshot in instance_tasks: node_id = instance_snapshot.node_id - + if node_id not in node_task_counts: node_task_counts[node_id] = { "Pending": 0, "Running": 0, } node_instance_counts[node_id] = 0 - + node_task_counts[node_id]["Pending"] += instance_snapshot.pending_tasks node_task_counts[node_id]["Running"] += instance_snapshot.running_tasks node_instance_counts[node_id] += 1 - + # Build node snapshots node_tasks: list[NodeTaskSnapshot] = [] for node_id, counts in node_task_counts.items(): pending = counts["Pending"] running = counts["Running"] total_active = pending + running - + node_tasks.append( NodeTaskSnapshot( node_id=node_id, @@ -278,7 +297,7 @@ def collect_metrics_snapshot(state: Mapping[str, Any]) -> MetricsSnapshot: instance_count=node_instance_counts.get(node_id, 0), ) ) - + return MetricsSnapshot( timestamp=timestamp, node_memory=node_memory, @@ -303,14 +322,16 @@ def count_instances_by_model(state: Mapping[str, Any], model_id: str) -> int: _instance_tag, instance_data = unwrap_tagged_union(instance_wrapped) if not isinstance(instance_data, dict): continue - + shard = instance_data.get("shardAssignments", {}) if isinstance(shard, dict) and shard.get("modelId") == model_id: count += 1 return count -def get_all_instance_ids_for_model(state: Mapping[str, Any], model_id: str) -> list[str]: +def get_all_instance_ids_for_model( + state: Mapping[str, Any], model_id: str +) -> list[str]: """Get all instance IDs for a given model_id.""" instances: Mapping[str, Any] = state.get("instances", {}) instance_ids = [] @@ -319,7 +340,7 @@ def get_all_instance_ids_for_model(state: Mapping[str, Any], model_id: str) -> l _instance_tag, instance_data = unwrap_tagged_union(instance_wrapped) if not isinstance(instance_data, dict): continue - + shard = instance_data.get("shardAssignments", {}) if isinstance(shard, dict) and shard.get("modelId") == model_id: instance_ids.append(instance_id) @@ -330,47 +351,49 @@ def count_ready_instances_by_model(state: Mapping[str, Any], model_id: str) -> i """Count how many instances for a model have all their runners ready.""" instances: Mapping[str, Any] = state.get("instances", {}) ready_count = 0 - + for instance_id, instance_wrapped in instances.items(): # Unwrap tagged Instance union _instance_tag, instance_data = unwrap_tagged_union(instance_wrapped) if not isinstance(instance_data, dict): continue - + shard = instance_data.get("shardAssignments", {}) if not isinstance(shard, dict) or shard.get("modelId") != model_id: continue - + # Check if all runners for this instance are ready runner_ids = get_runner_ids_for_instance(state, instance_id) if len(runner_ids) == 0: continue - + # Fixed runner status names: RunnerReady and RunnerRunning (not LoadedRunnerStatus/RunningRunnerStatus) all_ready = all( get_runner_status_kind(state, rid) in {"RunnerReady", "RunnerRunning"} for rid in runner_ids ) - + if all_ready: ready_count += 1 - + return ready_count -def get_runner_ids_for_instance(state: Mapping[str, Any], instance_id: str) -> list[str]: +def get_runner_ids_for_instance( + state: Mapping[str, Any], instance_id: str +) -> list[str]: instances: Mapping[str, Any] = state.get("instances", {}) instance_wrapped = instances.get(instance_id, {}) - + # Unwrap tagged Instance union _instance_tag, instance_data = unwrap_tagged_union(instance_wrapped) if not isinstance(instance_data, dict): return [] - + shard_assignments = instance_data.get("shardAssignments", {}) if not isinstance(shard_assignments, dict): return [] - + r2s = shard_assignments.get("runnerToShard", {}) if isinstance(r2s, dict): return list(r2s.keys()) @@ -387,43 +410,59 @@ def get_runner_status_kind(state: Mapping[str, Any], runner_id: str) -> str | No return None -async def wait_for_topology_ready(api_base: str, expected_nodes: int, timeout_s: int) -> None: +async def wait_for_topology_ready( + api_base: str, expected_nodes: int, timeout_s: int +) -> None: """Wait for all expected nodes to appear in the topology.""" - print(f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)...") + print( + f"Waiting for {expected_nodes} node(s) to appear in topology (timeout: {timeout_s}s)..." + ) start = time.monotonic() while True: state = fetch_state(api_base) node_count = get_topology_node_count(state) elapsed = time.monotonic() - start - print(f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)") - + print( + f" Topology has {node_count}/{expected_nodes} nodes (elapsed: {elapsed:.1f}s)" + ) + if node_count >= expected_nodes: print(f"All {expected_nodes} node(s) are in topology!") return - + if elapsed > timeout_s: - raise TimeoutError(f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}") + raise TimeoutError( + f"Timed out waiting for topology. Expected {expected_nodes} nodes, got {node_count}" + ) await asyncio.sleep(2) -async def wait_for_instances_ready(api_base: str, model_id: str, expected_count: int, timeout_s: int) -> list[str]: +async def wait_for_instances_ready( + api_base: str, model_id: str, expected_count: int, timeout_s: int +) -> list[str]: """Wait for a specific count of instances for a model to be fully ready.""" - print(f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)...") + print( + f"Waiting for {expected_count} instance(s) of {model_id} to be ready (timeout: {timeout_s}s)..." + ) start = time.monotonic() while True: state = fetch_state(api_base) - + total_count = count_instances_by_model(state, model_id) ready_count = count_ready_instances_by_model(state, model_id) elapsed = time.monotonic() - start - - print(f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)") - + + print( + f" Model {model_id}: {ready_count}/{expected_count} ready ({total_count} total) (elapsed: {elapsed:.1f}s)" + ) + if ready_count >= expected_count: instance_ids = get_all_instance_ids_for_model(state, model_id) - print(f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}") + print( + f"All {expected_count} instance(s) ready! Instance IDs: {instance_ids}" + ) return instance_ids - + if elapsed > timeout_s: raise TimeoutError( f"Timed out waiting for instances. Expected {expected_count} ready instances of {model_id}, " @@ -448,44 +487,52 @@ async def wait_for_all_instances_deleted(api_base: str, model_id: str) -> None: async def wait_for_tasks_drained(api_base: str, timeout_s: int = 600) -> None: """Wait for all tasks in the cluster to be drained (completed or failed). - + Tasks are deleted from state when complete, so we wait until there are no pending or running tasks remaining. """ - print(f"\n{'='*80}") + print(f"\n{'=' * 80}") print(f"⏳ WAITING FOR ALL TASKS TO DRAIN") - print(f"{'='*80}") + print(f"{'=' * 80}") start = time.monotonic() - + while True: state = fetch_state(api_base) snapshot = collect_metrics_snapshot(state) - + # Count total active tasks across all nodes total_pending = sum(node.pending_tasks for node in snapshot.node_tasks) total_running = sum(node.running_tasks for node in snapshot.node_tasks) total_active = total_pending + total_running - + elapsed = time.monotonic() - start - + if total_active == 0: print(f"✅ All tasks drained after {elapsed:.1f}s") return - - print(f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)") - + + print( + f" [{elapsed:.1f}s] Still draining: {total_active} active tasks ({total_pending} pending, {total_running} running)" + ) + # Print per-node breakdown if there are active tasks if snapshot.node_tasks: for node_snapshot in snapshot.node_tasks: if node_snapshot.total_active_tasks > 0: node_short = node_snapshot.node_id[-4:] - print(f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending") - + print( + f" Node ...{node_short}: {node_snapshot.running_tasks} running, {node_snapshot.pending_tasks} pending" + ) + if elapsed > timeout_s: - print(f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s") - print(f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)") + print( + f"⚠️ WARNING: Timed out waiting for tasks to drain after {timeout_s}s" + ) + print( + f" Remaining: {total_active} tasks ({total_pending} pending, {total_running} running)" + ) return - + await asyncio.sleep(2) @@ -545,6 +592,7 @@ class StageResult: @dataclass(frozen=True) class MemorySnapshot: """Memory snapshot for a node at a point in time.""" + ram_total_bytes: int ram_available_bytes: int ram_used_bytes: int @@ -556,10 +604,11 @@ class MemorySnapshot: @dataclass(frozen=True) class InstanceTaskSnapshot: """Task counts for an instance at a point in time. - + Note: Tasks are deleted from state when complete, so we only track active tasks. total_active_tasks = pending + running. """ + instance_id: str node_id: str pending_tasks: int @@ -570,10 +619,11 @@ class InstanceTaskSnapshot: @dataclass(frozen=True) class NodeTaskSnapshot: """Task counts for a node at a point in time. - + Note: Tasks are deleted from state when complete, so we only track active tasks. total_active_tasks = pending + running across all instances on this node. """ + node_id: str pending_tasks: int running_tasks: int @@ -584,6 +634,7 @@ class NodeTaskSnapshot: @dataclass(frozen=True) class MetricsSnapshot: """System metrics snapshot at a point in time.""" + timestamp: float node_memory: dict[str, MemorySnapshot] instance_tasks: list[InstanceTaskSnapshot] @@ -613,16 +664,16 @@ async def run_single_request( }, timeout=timeout, ) - + tokens = 0 got_done = False first_token_time: float | None = None last_token_time: float | None = None - + for line, timestamp in lines: if not line.startswith("data:"): continue - payload = line[len("data:"):].strip() + payload = line[len("data:") :].strip() if payload == "[DONE]": got_done = True break @@ -636,26 +687,28 @@ async def run_single_request( tokens += 1 except json.JSONDecodeError: continue - + elapsed = time.monotonic() - start completed_at = time.time() - + # Calculate TTFT and decode TPS time_to_first_token: float | None = None decode_tps: float | None = None - + if first_token_time is not None: time_to_first_token = first_token_time - start - + # Decode TPS: tokens per second after first token if last_token_time is not None and tokens > 1: decode_time = last_token_time - first_token_time if decode_time > 0: decode_tps = (tokens - 1) / decode_time - + # Request is only successful if we got at least one token AND a [DONE] marker if tokens == 0: - print(f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s") + print( + f" Request #{request_id}: FAILED - no tokens generated in {elapsed:.2f}s" + ) return RequestResult( request_id=request_id, success=False, @@ -665,11 +718,13 @@ async def run_single_request( completed_at=completed_at, time_to_first_token_s=time_to_first_token, decode_tps=decode_tps, - error="No tokens generated" + error="No tokens generated", ) - + if not got_done: - print(f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s") + print( + f" Request #{request_id}: FAILED - incomplete response (no [DONE]) after {elapsed:.2f}s" + ) return RequestResult( request_id=request_id, success=False, @@ -679,12 +734,16 @@ async def run_single_request( completed_at=completed_at, time_to_first_token_s=time_to_first_token, decode_tps=decode_tps, - error="Incomplete response (no [DONE] marker)" + error="Incomplete response (no [DONE] marker)", ) - - ttft_str = f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A" + + ttft_str = ( + f"{time_to_first_token:.3f}s" if time_to_first_token is not None else "N/A" + ) tps_str = f"{decode_tps:.1f} t/s" if decode_tps is not None else "N/A" - print(f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})") + print( + f" Request #{request_id}: SUCCESS - {tokens} tokens in {elapsed:.2f}s (TTFT: {ttft_str}, Decode: {tps_str})" + ) return RequestResult( request_id=request_id, success=True, @@ -693,9 +752,9 @@ async def run_single_request( started_at=started_at, completed_at=completed_at, time_to_first_token_s=time_to_first_token, - decode_tps=decode_tps + decode_tps=decode_tps, ) - + except Exception as e: elapsed = time.monotonic() - start completed_at = time.time() @@ -710,7 +769,7 @@ async def run_single_request( completed_at=completed_at, time_to_first_token_s=None, decode_tps=None, - error=error_details + error=error_details, ) @@ -721,10 +780,10 @@ async def monitor_metrics( interval_seconds: float = 5.0, ) -> None: """Background task that collects metrics snapshots every interval_seconds.""" - print(f"\n{'='*80}") + print(f"\n{'=' * 80}") print(f"🔍 METRICS MONITORING STARTED (polling every {interval_seconds}s)") - print(f"{'='*80}\n") - + print(f"{'=' * 80}\n") + snapshot_count = 0 while not stop_event.is_set(): try: @@ -732,30 +791,35 @@ async def monitor_metrics( state = fetch_state(api_base) snapshot = collect_metrics_snapshot(state) metrics_snapshots.append(snapshot) - + # Print detailed summary node_count = len(snapshot.node_memory) instance_count = len(snapshot.instance_tasks) - + # Aggregate task counts from node level (only active tasks in state) total_pending = sum(node.pending_tasks for node in snapshot.node_tasks) total_running = sum(node.running_tasks for node in snapshot.node_tasks) total_active = sum(node.total_active_tasks for node in snapshot.node_tasks) - + # Print detailed breakdown - print(f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)") - + print( + f"\n[METRICS #{snapshot_count}] {node_count} nodes, {instance_count} instances | Active Tasks: {total_active} ({total_pending} pending, {total_running} running)" + ) + # Print per-node breakdown (only if there are nodes) if snapshot.node_tasks: for node_snapshot in snapshot.node_tasks: node_short = node_snapshot.node_id[-4:] - print(f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances") - + print( + f" Node ...{node_short}: {node_snapshot.total_active_tasks} active ({node_snapshot.pending_tasks} pending, {node_snapshot.running_tasks} running) across {node_snapshot.instance_count} instances" + ) + except Exception as e: print(f"[METRICS] Error collecting snapshot: {e}") import traceback + traceback.print_exc() - + # Wait for interval or until stopped try: await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds) @@ -779,18 +843,20 @@ async def run_stage( print(f" Iterations: {stage.iterations}") print(f" No Overlap: {no_overlap}") print("=" * 80) - + stage_started_at = time.time() prompt = generate_prompt(stage.prompt_length) results: list[RequestResult] = [] - + if no_overlap: # Sequential execution: wait for each request to complete before starting next print("\nRunning requests sequentially (no overlap)...") for i in range(stage.iterations): - result = await run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1) + result = await run_single_request( + api_base, model_id, prompt, stage.generation_length, i + 1 + ) results.append(result) - + # Wait before starting next request (except after last one) if i < stage.iterations - 1: await asyncio.sleep(stage.time_between_requests) @@ -798,28 +864,30 @@ async def run_stage( # Concurrent execution: fire-and-forget with delays between starts print("\nRunning requests concurrently (with overlap)...") tasks: list[asyncio.Task[RequestResult]] = [] - + # Fire off requests with delays between them for i in range(stage.iterations): task = asyncio.create_task( - run_single_request(api_base, model_id, prompt, stage.generation_length, i + 1) + run_single_request( + api_base, model_id, prompt, stage.generation_length, i + 1 + ) ) tasks.append(task) - + # Wait before firing next request (except after last one) if i < stage.iterations - 1: await asyncio.sleep(stage.time_between_requests) - + # Wait for all requests to complete print(f"\nWaiting for all {len(tasks)} HTTP requests to complete...") results = list(await asyncio.gather(*tasks)) - + # Wait for all tasks in the cluster to be drained print(f"\nHTTP requests completed. Now waiting for cluster tasks to drain...") await wait_for_tasks_drained(api_base, timeout_s=600) - + stage_completed_at = time.time() - + # Compute statistics successful = sum(1 for r in results if r.success) failed = len(results) - successful @@ -828,37 +896,55 @@ async def run_stage( total_time = sum(r.elapsed_s for r in results) avg_tokens = total_tokens / successful if successful > 0 else 0.0 avg_time = total_time / successful if successful > 0 else 0.0 - + # Calculate average TTFT and decode TPS for successful requests only successful_results = [r for r in results if r.success] - + # Skip first iteration if there are more than 1 iterations (warmup) - results_for_stats = successful_results[1:] if len(successful_results) > 1 else successful_results - + results_for_stats = ( + successful_results[1:] if len(successful_results) > 1 else successful_results + ) + # TTFT statistics - ttft_values = [r.time_to_first_token_s for r in results_for_stats if r.time_to_first_token_s is not None] + ttft_values = [ + r.time_to_first_token_s + for r in results_for_stats + if r.time_to_first_token_s is not None + ] avg_ttft = sum(ttft_values) / len(ttft_values) if ttft_values else None - + if avg_ttft is not None and len(ttft_values) > 1: variance_ttft = sum((x - avg_ttft) ** 2 for x in ttft_values) / len(ttft_values) - std_ttft = variance_ttft ** 0.5 + std_ttft = variance_ttft**0.5 else: std_ttft = None - + # Decode TPS and ms per token statistics - decode_tps_values = [r.decode_tps for r in results_for_stats if r.decode_tps is not None] - avg_decode_tps = sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None - + decode_tps_values = [ + r.decode_tps for r in results_for_stats if r.decode_tps is not None + ] + avg_decode_tps = ( + sum(decode_tps_values) / len(decode_tps_values) if decode_tps_values else None + ) + # Convert to ms per token - ms_per_token_values = [1000.0 / tps for tps in decode_tps_values] if decode_tps_values else [] - avg_ms_per_token = sum(ms_per_token_values) / len(ms_per_token_values) if ms_per_token_values else None - + ms_per_token_values = ( + [1000.0 / tps for tps in decode_tps_values] if decode_tps_values else [] + ) + avg_ms_per_token = ( + sum(ms_per_token_values) / len(ms_per_token_values) + if ms_per_token_values + else None + ) + if avg_ms_per_token is not None and len(ms_per_token_values) > 1: - variance_ms_per_token = sum((x - avg_ms_per_token) ** 2 for x in ms_per_token_values) / len(ms_per_token_values) - std_ms_per_token = variance_ms_per_token ** 0.5 + variance_ms_per_token = sum( + (x - avg_ms_per_token) ** 2 for x in ms_per_token_values + ) / len(ms_per_token_values) + std_ms_per_token = variance_ms_per_token**0.5 else: std_ms_per_token = None - + return StageResult( name=stage.name, total_requests=len(results), @@ -892,11 +978,11 @@ async def run_benchmark( ) -> int: """Run the full staged benchmark.""" benchmark_started_at = time.time() - + # Load configuration with open(config_path) as f: config = yaml.safe_load(f) - + # Support both model_id (legacy) and model_ids (new) if "model_ids" in config: model_ids = config["model_ids"] @@ -904,51 +990,60 @@ async def run_benchmark( model_ids = [config["model_id"]] else: raise ValueError("Config must contain either 'model_id' or 'model_ids'") - + # Get sharding and instance_meta (optional, defaults to None if not specified) sharding: str | None = config.get("sharding") instance_meta: str | None = config.get("instance_meta") - + # Get no_overlap flag (optional, defaults to False) no_overlap: bool = config.get("no_overlap", False) - + stages = [StageConfig(**s) for s in config["stages"]] - + print("=" * 80) print("EXO BENCHMARK") print("=" * 80) print(f"Configuration File: {config_path}") print(f"Model IDs: {model_ids}") print(f"Instance Count: {len(model_ids)}") - print(f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}") - print(f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}") + print( + f"Sharding: {sharding if sharding else 'not specified (defaults to Pipeline)'}" + ) + print( + f"Instance Type: {instance_meta if instance_meta else 'not specified (defaults to MlxRing)'}" + ) print(f"No Overlap: {no_overlap}") print(f"Stages: {len(stages)}") print(f"Expected Nodes: {expected_nodes}") print(f"Is Primary: {is_primary}") print("=" * 80) - + try: # Wait for all nodes to join the topology first - await wait_for_topology_ready(api_base, expected_nodes, timeout_s=timeout_seconds) - + await wait_for_topology_ready( + api_base, expected_nodes, timeout_s=timeout_seconds + ) + # Add 30 second delay to allow topology to stabilize before creating instances - print(f"\nWaiting 30 seconds for topology to stabilize before creating instances...") + print( + f"\nWaiting 30 seconds for topology to stabilize before creating instances..." + ) await asyncio.sleep(30) print("Proceeding with instance creation\n") - + # Count how many instances we need for each unique model_id from collections import Counter + model_counts = Counter(model_ids) - + print(f"\nTarget instance counts by model:") for model_id, count in model_counts.items(): print(f" {model_id}: {count} instance(s)") print() - + # Track all instance IDs (collected at the end) all_instance_ids: list[str] = [] - + if is_primary: # Primary: create instances one at a time, waiting for count to increase for idx, model_id in enumerate(model_ids): @@ -956,50 +1051,58 @@ async def run_benchmark( current_state = fetch_state(api_base) current_ready = count_ready_instances_by_model(current_state, model_id) target_count = current_ready + 1 - + print("=" * 80) - print(f"[PRIMARY] Creating instance {idx+1}/{len(model_ids)} for model: {model_id}") - print(f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}") - + print( + f"[PRIMARY] Creating instance {idx + 1}/{len(model_ids)} for model: {model_id}" + ) + print( + f"[PRIMARY] Current ready count for {model_id}: {current_ready}, target: {target_count}" + ) + # Build instance creation request data instance_data: dict[str, Any] = {"model_id": model_id} if sharding is not None: instance_data["sharding"] = sharding if instance_meta is not None: instance_data["instance_meta"] = instance_meta - + response = await _http_request_async( - f"{api_base}/instance", - method="POST", - data=instance_data + f"{api_base}/instance", method="POST", data=instance_data ) print(f"[PRIMARY] Instance creation response: {response}") - + # Wait for one more instance of this model to be ready - await wait_for_instances_ready(api_base, model_id, target_count, timeout_s=timeout_seconds) - print(f"[PRIMARY] Instance {idx+1}/{len(model_ids)} is ready") + await wait_for_instances_ready( + api_base, model_id, target_count, timeout_s=timeout_seconds + ) + print(f"[PRIMARY] Instance {idx + 1}/{len(model_ids)} is ready") print("=" * 80) else: # Secondary: wait for expected counts of each model to be ready print("[SECONDARY] Waiting for all instances to be created and ready...") for model_id, expected_count in model_counts.items(): - await wait_for_instances_ready(api_base, model_id, expected_count, timeout_s=timeout_seconds) - + await wait_for_instances_ready( + api_base, model_id, expected_count, timeout_s=timeout_seconds + ) + # Collect all instance IDs for all models state = fetch_state(api_base) for model_id in model_counts.keys(): ids = get_all_instance_ids_for_model(state, model_id) all_instance_ids.extend(ids) - + # Count total runners total_runners = 0 for instance_id in all_instance_ids: runner_ids = get_runner_ids_for_instance(state, instance_id) total_runners += len(runner_ids) - - print(f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!") + + print( + f"\nAll {len(all_instance_ids)} instance(s) with {total_runners} total runner(s) are ready!" + ) print(f"Instance IDs: {all_instance_ids}") - + if is_primary: # Run all stages once (requests will use available instances) # We use the first model_id for the benchmark requests @@ -1008,25 +1111,29 @@ async def run_benchmark( print(f"RUNNING BENCHMARK (using model: {benchmark_model_id})") print(f"Instances available: {len(all_instance_ids)}") print(f"{'=' * 80}") - + # Start metrics monitoring with 500ms interval to catch fast-completing tasks metrics_snapshots: list[MetricsSnapshot] = [] stop_monitoring = asyncio.Event() monitoring_task = asyncio.create_task( - monitor_metrics(api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5) + monitor_metrics( + api_base, metrics_snapshots, stop_monitoring, interval_seconds=0.5 + ) ) - + stage_results: list[StageResult] = [] for stage in stages: - result = await run_stage(api_base, benchmark_model_id, stage, no_overlap=no_overlap) + result = await run_stage( + api_base, benchmark_model_id, stage, no_overlap=no_overlap + ) stage_results.append(result) - + # Stop metrics monitoring print("\nStopping metrics monitoring...") stop_monitoring.set() await monitoring_task print(f"Collected {len(metrics_snapshots)} metrics snapshots") - + # Print final results print("\n" + "=" * 80) print("BENCHMARK COMPLETE - RESULTS SUMMARY") @@ -1034,7 +1141,7 @@ async def run_benchmark( print(f"Instances tested: {len(all_instance_ids)}") print(f"Model IDs: {model_ids}") print(f"Instance IDs: {all_instance_ids}") - + for result in stage_results: print(f"\nStage: {result.name}") print(f" Total Requests: {result.total_requests}") @@ -1046,19 +1153,25 @@ async def run_benchmark( print(f" Avg Time/Request: {result.avg_time_per_request:.2f}s") if result.avg_time_to_first_token is not None: if result.std_time_to_first_token is not None: - print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s") + print( + f" Avg TTFT: {result.avg_time_to_first_token:.3f}s ± {result.std_time_to_first_token:.3f}s" + ) else: - print(f" Avg TTFT: {result.avg_time_to_first_token:.3f}s") + print( + f" Avg TTFT: {result.avg_time_to_first_token:.3f}s" + ) if result.avg_ms_per_token is not None: if result.std_ms_per_token is not None: - print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms") + print( + f" Avg ms/token: {result.avg_ms_per_token:.2f}ms ± {result.std_ms_per_token:.2f}ms" + ) else: print(f" Avg ms/token: {result.avg_ms_per_token:.2f}ms") if result.avg_decode_tps is not None: print(f" Avg Decode TPS: {result.avg_decode_tps:.2f} tokens/s") - + benchmark_completed_at = time.time() - + # Build comprehensive results document results_doc = { "metadata": { @@ -1100,16 +1213,33 @@ async def run_benchmark( "failed_requests": r.failed_requests, "success_rate": round(r.success_rate, 4), "total_tokens": r.total_tokens, - "avg_tokens_per_request": round(r.avg_tokens_per_request, 2), + "avg_tokens_per_request": round( + r.avg_tokens_per_request, 2 + ), "avg_time_per_request": round(r.avg_time_per_request, 3), - "avg_time_to_first_token": round(r.avg_time_to_first_token, 3) if r.avg_time_to_first_token is not None else None, - "std_time_to_first_token": round(r.std_time_to_first_token, 3) if r.std_time_to_first_token is not None else None, - "avg_decode_tps": round(r.avg_decode_tps, 2) if r.avg_decode_tps is not None else None, - "avg_ms_per_token": round(r.avg_ms_per_token, 2) if r.avg_ms_per_token is not None else None, - "std_ms_per_token": round(r.std_ms_per_token, 2) if r.std_ms_per_token is not None else None, + "avg_time_to_first_token": round( + r.avg_time_to_first_token, 3 + ) + if r.avg_time_to_first_token is not None + else None, + "std_time_to_first_token": round( + r.std_time_to_first_token, 3 + ) + if r.std_time_to_first_token is not None + else None, + "avg_decode_tps": round(r.avg_decode_tps, 2) + if r.avg_decode_tps is not None + else None, + "avg_ms_per_token": round(r.avg_ms_per_token, 2) + if r.avg_ms_per_token is not None + else None, + "std_ms_per_token": round(r.std_ms_per_token, 2) + if r.std_ms_per_token is not None + else None, "stage_started_at": r.stage_started_at, "stage_completed_at": r.stage_completed_at, - "stage_duration_s": r.stage_completed_at - r.stage_started_at, + "stage_duration_s": r.stage_completed_at + - r.stage_started_at, "requests": [ { "request_id": req.request_id, @@ -1118,12 +1248,18 @@ async def run_benchmark( "elapsed_s": round(req.elapsed_s, 3), "started_at": req.started_at, "completed_at": req.completed_at, - "time_to_first_token_s": round(req.time_to_first_token_s, 3) if req.time_to_first_token_s is not None else None, - "decode_tps": round(req.decode_tps, 2) if req.decode_tps is not None else None, + "time_to_first_token_s": round( + req.time_to_first_token_s, 3 + ) + if req.time_to_first_token_s is not None + else None, + "decode_tps": round(req.decode_tps, 2) + if req.decode_tps is not None + else None, "error": req.error, } for req in r.request_results - ] + ], } for r in stage_results ] @@ -1162,40 +1298,44 @@ async def run_benchmark( "instance_count": node.instance_count, } for node in snapshot.node_tasks - ] + ], } for snapshot in metrics_snapshots ] - } + }, } - + # Output JSON summary print("\n" + "=" * 80) print("JSON RESULTS") print("=" * 80) print(json.dumps(results_doc, indent=2)) print("=" * 80) - + # Save to file if path provided if results_output_path: print(f"Saving results to: {results_output_path}") with open(results_output_path, "w") as f: json.dump(results_doc, f, indent=2) print(f"Results saved successfully") - + # Cleanup all instances for instance_id in all_instance_ids: print(f"[PRIMARY] Cleaning up instance: {instance_id}") - await _http_request_async(f"{api_base}/instance/{instance_id}", method="DELETE") + await _http_request_async( + f"{api_base}/instance/{instance_id}", method="DELETE" + ) print(f"[PRIMARY] Instance {instance_id} deleted successfully") else: - print("[SECONDARY] Waiting with cluster (primary handles benchmark execution)") + print( + "[SECONDARY] Waiting with cluster (primary handles benchmark execution)" + ) # Secondary nodes wait until all instances of all models are deleted for model_id in model_counts.keys(): await wait_for_all_instances_deleted(api_base, model_id) - + return 0 - + except TimeoutError as e: print("=" * 80) print(f"TIMEOUT ERROR: {e}") @@ -1205,39 +1345,56 @@ async def run_benchmark( print("=" * 80) print(f"ERROR: {e}") import traceback + traceback.print_exc() print("=" * 80) return 1 def main() -> int: - parser = argparse.ArgumentParser(description="Run unified benchmark for EXO (single or multi-stage)") + parser = argparse.ArgumentParser( + description="Run unified benchmark for EXO (single or multi-stage)" + ) parser.add_argument("--api-port", type=int, required=True) - parser.add_argument("--config", type=Path, required=True, help="Path to YAML config file") - parser.add_argument("--expected-nodes", type=int, required=True, help="Total number of nodes expected in the cluster") - parser.add_argument("--is-primary", type=str, choices=["true", "false"], required=True) + parser.add_argument( + "--config", type=Path, required=True, help="Path to YAML config file" + ) + parser.add_argument( + "--expected-nodes", + type=int, + required=True, + help="Total number of nodes expected in the cluster", + ) + parser.add_argument( + "--is-primary", type=str, choices=["true", "false"], required=True + ) parser.add_argument("--timeout-seconds", type=int, default=1800) - parser.add_argument("--output", type=Path, help="Path to save detailed results JSON") + parser.add_argument( + "--output", type=Path, help="Path to save detailed results JSON" + ) parser.add_argument("--git-commit", type=str, help="Git commit hash for metadata") - parser.add_argument("--hardware-labels", type=str, help="Comma-separated hardware labels") + parser.add_argument( + "--hardware-labels", type=str, help="Comma-separated hardware labels" + ) args = parser.parse_args() - + api_base = f"http://localhost:{args.api_port}" is_primary = args.is_primary.lower() == "true" hardware_labels = args.hardware_labels.split(",") if args.hardware_labels else None - - return asyncio.run(run_benchmark( - api_base, - args.config, - args.expected_nodes, - is_primary, - args.timeout_seconds, - results_output_path=args.output, - git_commit=args.git_commit, - hardware_labels=hardware_labels, - )) + + return asyncio.run( + run_benchmark( + api_base, + args.config, + args.expected_nodes, + is_primary, + args.timeout_seconds, + results_output_path=args.output, + git_commit=args.git_commit, + hardware_labels=hardware_labels, + ) + ) if __name__ == "__main__": sys.exit(main()) - diff --git a/.github/scripts/build_matrix.py b/.github/scripts/build_matrix.py index 324495df..a54cbf7b 100644 --- a/.github/scripts/build_matrix.py +++ b/.github/scripts/build_matrix.py @@ -24,12 +24,12 @@ class Config(TypedDict): # Read the config file -config_file: str = os.environ['CONFIG_FILE'] -with open(config_file, 'r') as f: +config_file: str = os.environ["CONFIG_FILE"] +with open(config_file, "r") as f: config: Config = cast(Config, yaml.safe_load(f)) # Extract hardware plan from config -plan: dict[str, int] = config['hardware_plan'] +plan: dict[str, int] = config["hardware_plan"] if not plan: raise ValueError(f"No hardware_plan found in {config_file}") @@ -40,22 +40,24 @@ for label, count in plan.items(): entries.append({"label": label, "index": idx}) total_nodes: int = len(entries) -matrix: dict[str, list[MatrixInclude]] = {"include": [ - { - "label": e["label"], - "index": e["index"], - "is_primary": (i == 0), - "expected_nodes": total_nodes - } - for i, e in enumerate(entries) -]} +matrix: dict[str, list[MatrixInclude]] = { + "include": [ + { + "label": e["label"], + "index": e["index"], + "is_primary": (i == 0), + "expected_nodes": total_nodes, + } + for i, e in enumerate(entries) + ] +} # Extract other config values -timeout_seconds: int = config.get('timeout_seconds', 600) -environment: dict[str, str] = config.get('environment', {}) +timeout_seconds: int = config.get("timeout_seconds", 600) +environment: dict[str, str] = config.get("environment", {}) # Output to GitHub Actions -with open(os.environ['GITHUB_OUTPUT'], 'a') as f: +with open(os.environ["GITHUB_OUTPUT"], "a") as f: f.write(f"matrix={json.dumps(matrix)}\n") f.write(f"config_file={config_file}\n") f.write(f"timeout_seconds={timeout_seconds}\n") @@ -65,4 +67,3 @@ print(f"Matrix: {json.dumps(matrix)}") print(f"Config file: {config_file}") print(f"Timeout: {timeout_seconds}") print(f"Environment: {json.dumps(environment)}") - diff --git a/rust/exo_pyo3_bindings/src/networking.rs b/rust/exo_pyo3_bindings/src/networking.rs index bf02ec56..e2f88f2b 100644 --- a/rust/exo_pyo3_bindings/src/networking.rs +++ b/rust/exo_pyo3_bindings/src/networking.rs @@ -14,21 +14,20 @@ use libp2p::futures::StreamExt as _; use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError}; use libp2p::swarm::SwarmEvent; use libp2p::{gossipsub, mdns}; +use networking::discovery; +use networking::swarm::create_swarm; use pyo3::prelude::{PyModule, PyModuleMethods as _}; use pyo3::types::PyBytes; use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods}; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods}; use std::net::IpAddr; use tokio::sync::{Mutex, mpsc, oneshot}; -use networking::discovery; -use networking::swarm::create_swarm; use util::ext::VecExt as _; mod exception { - use pyo3::{exceptions::{PyException}, prelude::*, PyErrArguments}; use pyo3::types::PyTuple; - use pyo3_stub_gen::{derive::*}; - + use pyo3::{PyErrArguments, exceptions::PyException, prelude::*}; + use pyo3_stub_gen::derive::*; #[gen_stub_pyclass] #[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")] @@ -71,7 +70,8 @@ mod exception { pub struct PyAllQueuesFullError {} impl PyAllQueuesFullError { - const MSG: &'static str = "All libp2p peers are unresponsive, resend the message or reconnect."; + const MSG: &'static str = + "All libp2p peers are unresponsive, resend the message or reconnect."; /// Creates a new [ `PyErr` ] of this type. /// @@ -154,10 +154,10 @@ async fn networking_task( connection_update_tx: mpsc::Sender, gossipsub_message_tx: mpsc::Sender<(String, Vec)>, ) { - use networking::swarm::BehaviourEvent::*; use SwarmEvent::*; use ToTask::*; use mdns::Event::*; + use networking::swarm::BehaviourEvent::*; log::info!("RUST: networking task started"); @@ -367,7 +367,7 @@ impl PyNetworkingHandle { connection_update_tx, gossipsub_message_tx, ) - .await; + .await; }); Ok(Self::new( to_task_tx, diff --git a/rust/networking/examples/chatroom_manual.rs b/rust/networking/examples/chatroom_manual.rs index 6c1ffd88..5d92ac86 100644 --- a/rust/networking/examples/chatroom_manual.rs +++ b/rust/networking/examples/chatroom_manual.rs @@ -18,17 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use std::{ - error::Error, - hash::{Hash}, -}; -use std::time::Duration; use futures::stream::StreamExt; use libp2p::{ gossipsub, mdns, noise, swarm::{NetworkBehaviour, SwarmEvent}, tcp, yamux, }; +use std::time::Duration; +use std::{error::Error, hash::Hash}; use tokio::{io, io::AsyncBufReadExt, select}; use tracing_subscriber::EnvFilter; @@ -127,4 +124,4 @@ async fn main() -> Result<(), Box> { } } } -} \ No newline at end of file +} diff --git a/rust/networking/src/discovery.rs b/rust/networking/src/discovery.rs index 64a297c3..b9a4052c 100644 --- a/rust/networking/src/discovery.rs +++ b/rust/networking/src/discovery.rs @@ -1,3 +1,4 @@ +use crate::ext::MultiaddrExt; use crate::keep_alive; use delegate::delegate; use either::Either; @@ -7,7 +8,11 @@ use libp2p::core::transport::PortUse; use libp2p::core::{ConnectedPoint, Endpoint}; use libp2p::swarm::behaviour::ConnectionEstablished; use libp2p::swarm::dial_opts::DialOpts; -use libp2p::swarm::{dummy, CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm}; +use libp2p::swarm::{ + CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler, + ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, + THandlerOutEvent, ToSwarm, dummy, +}; use libp2p::{Multiaddr, PeerId, identity, mdns}; use std::collections::{BTreeSet, HashMap}; use std::convert::Infallible; @@ -16,16 +21,14 @@ use std::net::IpAddr; use std::task::{Context, Poll}; use std::time::Duration; use util::wakerdeque::WakerDeque; -use crate::ext::MultiaddrExt; - const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5); mod managed { + use libp2p::swarm::NetworkBehaviour; + use libp2p::{identity, mdns, ping}; use std::io; use std::time::Duration; - use libp2p::{identity, mdns, ping}; - use libp2p::swarm::NetworkBehaviour; const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500); const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500); @@ -64,7 +67,11 @@ mod managed { } fn ping_behaviour() -> ping::Behaviour { - ping::Behaviour::new(ping::Config::new().with_timeout(PING_TIMEOUT).with_interval(PING_INTERVAL)) + ping::Behaviour::new( + ping::Config::new() + .with_timeout(PING_TIMEOUT) + .with_interval(PING_INTERVAL), + ) } } @@ -129,7 +136,6 @@ impl Behaviour { }) } - fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) { for (p, ma) in peers { self.dial(p, ma.clone()); // always connect @@ -202,7 +208,7 @@ impl Behaviour { impl NetworkBehaviour for Behaviour { type ConnectionHandler = - ConnectionHandlerSelect>; + ConnectionHandlerSelect>; type ToSwarm = Event; // simply delegate to underlying mDNS behaviour @@ -261,11 +267,10 @@ impl NetworkBehaviour for Behaviour { ) { match event { Either::Left(ev) => libp2p::core::util::unreachable(ev), - Either::Right(ev) => self.managed.on_connection_handler_event( - peer_id, - connection_id, - ev, - ), + Either::Right(ev) => { + self.managed + .on_connection_handler_event(peer_id, connection_id, ev) + } } } @@ -277,11 +282,11 @@ impl NetworkBehaviour for Behaviour { // handle swarm events to update internal state: match event { FromSwarm::ConnectionEstablished(ConnectionEstablished { - peer_id, - connection_id, - endpoint, - .. - }) => { + peer_id, + connection_id, + endpoint, + .. + }) => { let remote_address = match endpoint { ConnectedPoint::Dialer { address, .. } => address, ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr, @@ -293,11 +298,11 @@ impl NetworkBehaviour for Behaviour { } } FromSwarm::ConnectionClosed(ConnectionClosed { - peer_id, - connection_id, - endpoint, - .. - }) => { + peer_id, + connection_id, + endpoint, + .. + }) => { let remote_address = match endpoint { ConnectedPoint::Dialer { address, .. } => address, ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr, @@ -331,7 +336,7 @@ impl NetworkBehaviour for Behaviour { mdns::Event::Expired(peers) => { self.handle_mdns_expired(peers); } - } + }, // handle ping events => if error then disconnect managed::BehaviourEvent::Ping(e) => { @@ -346,7 +351,6 @@ impl NetworkBehaviour for Behaviour { cx.waker().wake_by_ref(); } - // forward any other mDNS event to the swarm or its connection handler(s) Poll::Ready(e) => { return Poll::Ready( diff --git a/rust/networking/src/keep_alive.rs b/rust/networking/src/keep_alive.rs index eb67aecb..881b11d7 100644 --- a/rust/networking/src/keep_alive.rs +++ b/rust/networking/src/keep_alive.rs @@ -20,13 +20,13 @@ impl handler::ConnectionHandler for ConnectionHandler { type FromBehaviour = ::FromBehaviour; type ToBehaviour = ::ToBehaviour; type InboundProtocol = - ::InboundProtocol; + ::InboundProtocol; type OutboundProtocol = - ::OutboundProtocol; + ::OutboundProtocol; type InboundOpenInfo = - ::InboundOpenInfo; + ::InboundOpenInfo; type OutboundOpenInfo = - ::OutboundOpenInfo; + ::OutboundOpenInfo; delegate! { to self.0 { diff --git a/rust/networking/src/lib.rs b/rust/networking/src/lib.rs index a83bdc71..59b83817 100644 --- a/rust/networking/src/lib.rs +++ b/rust/networking/src/lib.rs @@ -28,10 +28,10 @@ pub(crate) mod alias { /// Namespace for crate-wide extension traits/methods pub(crate) mod ext { - use std::net::IpAddr; use extend::ext; use libp2p::Multiaddr; use libp2p::multiaddr::Protocol; + use std::net::IpAddr; #[ext(pub, name = MultiaddrExt)] impl Multiaddr { @@ -42,7 +42,7 @@ pub(crate) mod ext { match p { Protocol::Ip4(ip) => IpAddr::V4(ip), Protocol::Ip6(ip) => IpAddr::V6(ip), - _ => return None + _ => return None, } } else { return None; @@ -61,4 +61,4 @@ pub(crate) mod private { /// Sealed traits support pub trait Sealed {} impl Sealed for T {} -} \ No newline at end of file +} diff --git a/rust/networking/src/swarm.rs b/rust/networking/src/swarm.rs index 8be3f160..a5c87af5 100644 --- a/rust/networking/src/swarm.rs +++ b/rust/networking/src/swarm.rs @@ -37,19 +37,20 @@ mod transport { use libp2p::core::transport::Boxed; use libp2p::pnet::{PnetError, PnetOutput}; use libp2p::{PeerId, Transport, identity, noise, pnet, yamux}; - use std::{sync::LazyLock, env}; + use std::{env, sync::LazyLock}; /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`]. /// See [`pnet_upgrade`] for more. static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| { let builder = Sha3_256::new().update(b"exo_discovery_network"); - + if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) { - let bytes = var.into_bytes(); + let bytes = var.into_bytes(); builder.update(&bytes) } else { builder.update(NETWORK_VERSION) - }.finalize() + } + .finalize() }); /// Make the Swarm run on a private network, as to not clash with public libp2p nodes and @@ -103,9 +104,9 @@ mod transport { mod behaviour { use crate::{alias, discovery}; - use std::time::Duration; use libp2p::swarm::NetworkBehaviour; use libp2p::{gossipsub, identity}; + use std::time::Duration; /// Behavior of the Swarm which composes all desired behaviors: /// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`]. @@ -139,6 +140,6 @@ mod behaviour { .build() .expect("the configuration should always be valid"), ) - .expect("creating gossipsub behavior should always work") + .expect("creating gossipsub behavior should always work") } } diff --git a/tmp/run_llm.py b/tmp/run_llm.py index 10f335b6..89a2e50b 100644 --- a/tmp/run_llm.py +++ b/tmp/run_llm.py @@ -27,7 +27,7 @@ def stream_chat(host: str, query: str) -> None: if not line.startswith("data:"): continue - data = line[len("data:"):].strip() + data = line[len("data:") :].strip() if data == "[DONE]": break @@ -55,7 +55,8 @@ def main() -> None: ) parser.add_argument("host", help="Hostname (without protocol), e.g. localhost") parser.add_argument( - "-f", "--file", + "-f", + "--file", help="Path to a text file whose contents will be used as the query", ) parser.add_argument( @@ -82,4 +83,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main()