mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
fmt, lint
This commit is contained in:
@@ -5,8 +5,8 @@ use std::time::Duration;
|
||||
|
||||
use iroh::SecretKey;
|
||||
use iroh_gossip::api::{Event, Message};
|
||||
use networking::ExoNet;
|
||||
use n0_future::StreamExt as _;
|
||||
use networking::ExoNet;
|
||||
use tokio::time::sleep;
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -16,7 +16,8 @@ use tracing_subscriber::filter::LevelFilter;
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init().expect("logger");
|
||||
.try_init()
|
||||
.expect("logger");
|
||||
|
||||
// Configure swarm
|
||||
let net = Arc::new(
|
||||
@@ -43,8 +44,10 @@ async fn main() {
|
||||
|
||||
let jh2 = tokio::spawn(async move {
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await && let Err(e) = send.broadcast(line.into()).await {
|
||||
println!("Publish error: {e:?}");
|
||||
if let Ok(Some(line)) = stdin.next_line().await
|
||||
&& let Err(e) = send.broadcast(line.into()).await
|
||||
{
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#![allow(clippy::cargo, clippy::unwrap_used)]
|
||||
use iroh::{SecretKey, endpoint_info::EndpointIdExt as _};
|
||||
use networking::ExoNet;
|
||||
use n0_future::StreamExt as _;
|
||||
use networking::ExoNet;
|
||||
|
||||
// Launch a mock version of iroh for testing purposes
|
||||
#[tokio::main]
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.master.placement_utils import (
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
@@ -30,7 +31,6 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.routing.connection_message import IpAddress
|
||||
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TypeGuard, cast
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -17,7 +18,6 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.routing.connection_message import IpAddress
|
||||
|
||||
|
||||
class NodeWithProfile(BaseModel):
|
||||
|
||||
@@ -11,7 +11,6 @@ from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CommandId,
|
||||
CreateInstance,
|
||||
ForwarderCommand,
|
||||
PlaceInstance,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -24,8 +25,6 @@ from exo.shared.types.worker.instances import (
|
||||
from exo.shared.types.worker.runners import ShardAssignments
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
from exo.master.tests.conftest import create_node, create_connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
|
||||
@@ -7,6 +7,7 @@ from exo.master.placement_utils import (
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -14,8 +15,6 @@ from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
from exo.master.tests.conftest import create_connection, create_node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from ipaddress import ip_address
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
@@ -8,7 +10,6 @@ from exo.shared.types.profiling import (
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from exo_pyo3_bindings import RustConnectionMessage
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from ipaddress import ip_address
|
||||
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
|
||||
|
||||
def test_state_serialization_roundtrip() -> None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, TopologyEdgeCreated, TopologyEdgeDeleted
|
||||
from exo.shared.types.events import Event, TopologyEdgeCreated
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from mlx_lm.models.cache import KVCache
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
@@ -17,6 +18,7 @@ class Model(nn.Module):
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
|
||||
class Detokenizer:
|
||||
def reset(self) -> None: ...
|
||||
def add_token(self, token: int) -> None: ...
|
||||
@@ -25,6 +27,7 @@ class Detokenizer:
|
||||
@property
|
||||
def last_segment(self) -> str: ...
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
bos_token: str | None
|
||||
eos_token_ids: list[int]
|
||||
|
||||
@@ -25,7 +25,6 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -33,6 +32,7 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
@@ -256,12 +256,12 @@ class Worker:
|
||||
|
||||
async def _connection_message_event_writer(self):
|
||||
with self.connection_message_receiver as connection_messages:
|
||||
async for msg in connection_messages:
|
||||
async for _msg in connection_messages:
|
||||
break
|
||||
# TODO: use mdns for partial discovery
|
||||
for event in check_connections(self.node_id, msg, self.state):
|
||||
logger.info(f"Worker discovered connection {event}")
|
||||
await self.event_sender.send(event)
|
||||
# for event in check_connections(self.node_id, msg, self.state):
|
||||
# logger.info(f"Worker discovered connection {event}")
|
||||
# await self.event_sender.send(event)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
|
||||
@@ -3,9 +3,9 @@ from ipaddress import ip_address
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.routing.connection_message import IpAddress
|
||||
|
||||
|
||||
# TODO: ref. api port
|
||||
|
||||
Reference in New Issue
Block a user