From 72c71a9abdda30d2b5f29cff28cc0367577e6b19 Mon Sep 17 00:00:00 2001 From: Evan Date: Mon, 22 Dec 2025 03:11:42 +0000 Subject: [PATCH] fmt, lint --- rust/networking/examples/chatroom.rs | 11 +++++++---- rust/networking/examples/mdns_responder.rs | 2 +- src/exo/master/placement.py | 2 +- src/exo/master/placement_utils.py | 2 +- src/exo/master/tests/test_master.py | 1 - src/exo/master/tests/test_placement.py | 3 +-- src/exo/master/tests/test_placement_utils.py | 3 +-- src/exo/master/tests/test_topology.py | 5 +++-- src/exo/routing/connection_message.py | 2 +- src/exo/shared/tests/test_state_serialization.py | 3 ++- src/exo/worker/_connection_handler.py | 2 +- src/exo/worker/engines/mlx/__init__.py | 3 +++ src/exo/worker/main.py | 10 +++++----- src/exo/worker/utils/net_profile.py | 2 +- 14 files changed, 28 insertions(+), 23 deletions(-) diff --git a/rust/networking/examples/chatroom.rs b/rust/networking/examples/chatroom.rs index 93906cf8..e5f1736c 100644 --- a/rust/networking/examples/chatroom.rs +++ b/rust/networking/examples/chatroom.rs @@ -5,8 +5,8 @@ use std::time::Duration; use iroh::SecretKey; use iroh_gossip::api::{Event, Message}; -use networking::ExoNet; use n0_future::StreamExt as _; +use networking::ExoNet; use tokio::time::sleep; use tokio::{io, io::AsyncBufReadExt as _}; use tracing_subscriber::EnvFilter; @@ -16,7 +16,8 @@ use tracing_subscriber::filter::LevelFilter; async fn main() { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into())) - .try_init().expect("logger"); + .try_init() + .expect("logger"); // Configure swarm let net = Arc::new( @@ -43,8 +44,10 @@ async fn main() { let jh2 = tokio::spawn(async move { loop { - if let Ok(Some(line)) = stdin.next_line().await && let Err(e) = send.broadcast(line.into()).await { - println!("Publish error: {e:?}"); + if let Ok(Some(line)) = stdin.next_line().await + && let Err(e) = send.broadcast(line.into()).await + { + println!("Publish error: {e:?}"); } } }); diff --git a/rust/networking/examples/mdns_responder.rs b/rust/networking/examples/mdns_responder.rs index eab9df1b..c314bee1 100644 --- a/rust/networking/examples/mdns_responder.rs +++ b/rust/networking/examples/mdns_responder.rs @@ -1,7 +1,7 @@ #![allow(clippy::cargo, clippy::unwrap_used)] use iroh::{SecretKey, endpoint_info::EndpointIdExt as _}; -use networking::ExoNet; use n0_future::StreamExt as _; +use networking::ExoNet; // Launch a mock version of iroh for testing purposes #[tokio::main] diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index 271acccc..916eaa87 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -13,6 +13,7 @@ from exo.master.placement_utils import ( get_shard_assignments, get_smallest_cycles, ) +from exo.routing.connection_message import IpAddress from exo.shared.topology import Topology from exo.shared.types.commands import ( CreateInstance, @@ -30,7 +31,6 @@ from exo.shared.types.worker.instances import ( MlxJacclInstance, MlxRingInstance, ) -from exo.routing.connection_message import IpAddress def random_ephemeral_port() -> int: diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 66ae39e3..8bd59ac6 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -4,6 +4,7 @@ from typing import TypeGuard, cast from loguru import logger from pydantic import BaseModel +from exo.routing.connection_message import IpAddress from exo.shared.topology import Topology from exo.shared.types.common import NodeId from exo.shared.types.memory import Memory @@ -17,7 +18,6 @@ from exo.shared.types.worker.shards import ( ShardMetadata, TensorShardMetadata, ) -from exo.routing.connection_message import IpAddress class NodeWithProfile(BaseModel): diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index 01336bce..f84eddc2 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -11,7 +11,6 @@ from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams from exo.shared.types.commands import ( ChatCompletion, CommandId, - CreateInstance, ForwarderCommand, PlaceInstance, ) diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 52174a85..53b28aad 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -7,6 +7,7 @@ from exo.master.placement import ( get_transition_events, place_instance, ) +from exo.master.tests.conftest import create_connection, create_node from exo.shared.topology import Topology from exo.shared.types.commands import PlaceInstance from exo.shared.types.common import CommandId, NodeId @@ -24,8 +25,6 @@ from exo.shared.types.worker.instances import ( from exo.shared.types.worker.runners import ShardAssignments from exo.shared.types.worker.shards import Sharding -from exo.master.tests.conftest import create_node, create_connection - @pytest.fixture def topology() -> Topology: diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index 1ef9f48e..840a9516 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -7,6 +7,7 @@ from exo.master.placement_utils import ( get_shard_assignments, get_smallest_cycles, ) +from exo.master.tests.conftest import create_connection, create_node from exo.shared.topology import Topology from exo.shared.types.common import Host, NodeId from exo.shared.types.memory import Memory @@ -14,8 +15,6 @@ from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile from exo.shared.types.worker.shards import Sharding -from exo.master.tests.conftest import create_connection, create_node - @pytest.fixture def topology() -> Topology: diff --git a/src/exo/master/tests/test_topology.py b/src/exo/master/tests/test_topology.py index 9ef55081..d3c0cb48 100644 --- a/src/exo/master/tests/test_topology.py +++ b/src/exo/master/tests/test_topology.py @@ -1,6 +1,8 @@ -import pytest from ipaddress import ip_address +import pytest + +from exo.routing.connection_message import SocketAddress from exo.shared.topology import Topology from exo.shared.types.profiling import ( MemoryPerformanceProfile, @@ -8,7 +10,6 @@ from exo.shared.types.profiling import ( SystemPerformanceProfile, ) from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo -from exo.routing.connection_message import SocketAddress @pytest.fixture diff --git a/src/exo/routing/connection_message.py b/src/exo/routing/connection_message.py index 807752a7..7687a6f2 100644 --- a/src/exo/routing/connection_message.py +++ b/src/exo/routing/connection_message.py @@ -1,7 +1,7 @@ from ipaddress import IPv4Address, IPv6Address, ip_address -from pydantic import ConfigDict from exo_pyo3_bindings import RustConnectionMessage +from pydantic import ConfigDict from exo.shared.types.common import NodeId from exo.utils.pydantic_ext import CamelCaseModel diff --git a/src/exo/shared/tests/test_state_serialization.py b/src/exo/shared/tests/test_state_serialization.py index bbaa5f1f..da2f732d 100644 --- a/src/exo/shared/tests/test_state_serialization.py +++ b/src/exo/shared/tests/test_state_serialization.py @@ -1,8 +1,9 @@ from ipaddress import ip_address + +from exo.routing.connection_message import SocketAddress from exo.shared.types.common import NodeId from exo.shared.types.state import State from exo.shared.types.topology import Connection -from exo.routing.connection_message import SocketAddress def test_state_serialization_roundtrip() -> None: diff --git a/src/exo/worker/_connection_handler.py b/src/exo/worker/_connection_handler.py index 4a2d1968..511affd3 100644 --- a/src/exo/worker/_connection_handler.py +++ b/src/exo/worker/_connection_handler.py @@ -1,6 +1,6 @@ from exo.routing.connection_message import ConnectionMessage from exo.shared.types.common import NodeId -from exo.shared.types.events import Event, TopologyEdgeCreated, TopologyEdgeDeleted +from exo.shared.types.events import Event, TopologyEdgeCreated from exo.shared.types.state import State from exo.shared.types.topology import Connection diff --git a/src/exo/worker/engines/mlx/__init__.py b/src/exo/worker/engines/mlx/__init__.py index bf8601f8..d6f0b6b3 100644 --- a/src/exo/worker/engines/mlx/__init__.py +++ b/src/exo/worker/engines/mlx/__init__.py @@ -7,6 +7,7 @@ from mlx_lm.models.cache import KVCache # These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is. # For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function + class Model(nn.Module): layers: list[nn.Module] @@ -17,6 +18,7 @@ class Model(nn.Module): input_embeddings: mx.array | None = None, ) -> mx.array: ... + class Detokenizer: def reset(self) -> None: ... def add_token(self, token: int) -> None: ... @@ -25,6 +27,7 @@ class Detokenizer: @property def last_segment(self) -> str: ... + class TokenizerWrapper: bos_token: str | None eos_token_ids: list[int] diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 1134ba61..ea38f07f 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -25,7 +25,6 @@ from exo.shared.types.events import ( ) from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile from exo.shared.types.state import State -from exo.shared.types.topology import Connection from exo.shared.types.tasks import ( CreateRunner, DownloadModel, @@ -33,6 +32,7 @@ from exo.shared.types.tasks import ( Task, TaskStatus, ) +from exo.shared.types.topology import Connection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadOngoing, @@ -256,12 +256,12 @@ class Worker: async def _connection_message_event_writer(self): with self.connection_message_receiver as connection_messages: - async for msg in connection_messages: + async for _msg in connection_messages: break # TODO: use mdns for partial discovery - for event in check_connections(self.node_id, msg, self.state): - logger.info(f"Worker discovered connection {event}") - await self.event_sender.send(event) + # for event in check_connections(self.node_id, msg, self.state): + # logger.info(f"Worker discovered connection {event}") + # await self.event_sender.send(event) async def _nack_request(self, since_idx: int) -> None: # We request all events after (and including) the missing index. diff --git a/src/exo/worker/utils/net_profile.py b/src/exo/worker/utils/net_profile.py index 8563a6d0..c554751c 100644 --- a/src/exo/worker/utils/net_profile.py +++ b/src/exo/worker/utils/net_profile.py @@ -3,9 +3,9 @@ from ipaddress import ip_address from anyio import create_task_group, to_thread +from exo.routing.connection_message import IpAddress from exo.shared.topology import Topology from exo.shared.types.common import NodeId -from exo.routing.connection_message import IpAddress # TODO: ref. api port