mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
fix pydantic validation
This commit is contained in:
@@ -8,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.topology import SocketConnection, RDMAConnection
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
|
||||
|
||||
def test_state_serialization_roundtrip() -> None:
|
||||
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
|
||||
node_a = NodeId("node-a")
|
||||
node_b = NodeId("node-b")
|
||||
|
||||
connection = Connection(
|
||||
local_node_id=node_a,
|
||||
send_back_node_id=node_b,
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
connection = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
)
|
||||
|
||||
state = State()
|
||||
state.topology.add_connection(connection)
|
||||
state.topology.add_connection(node_a, node_b, connection)
|
||||
|
||||
json_repr = state.model_dump_json()
|
||||
restored_state = State.model_validate_json(json_repr)
|
||||
|
||||
@@ -13,7 +13,7 @@ class TopologySnapshot(BaseModel):
|
||||
nodes: list[NodeId]
|
||||
connections: list[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
|
||||
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
model_config = ConfigDict(frozen=True, extra="forbid")
|
||||
|
||||
|
||||
class Topology:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -28,7 +28,7 @@ class TBConnectivityItem(BaseModel, extra="ignore"):
|
||||
class TBConnectivityData(BaseModel, extra="ignore"):
|
||||
domain_uuid_key: str | None
|
||||
device_name_key: str
|
||||
_items: list[TBConnectivityItem] | None
|
||||
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
|
||||
receptacle_1_tag: TBReceptacleTag
|
||||
|
||||
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
|
||||
@@ -39,12 +39,12 @@ class TBConnectivityData(BaseModel, extra="ignore"):
|
||||
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
|
||||
|
||||
def conn(self) -> TBConnection | None:
|
||||
if self.domain_uuid_key is None or self._items is None:
|
||||
if self.domain_uuid_key is None or self.items is None:
|
||||
return
|
||||
|
||||
sink_key = next(
|
||||
item.domain_uuid_key
|
||||
for item in self._items
|
||||
for item in self.items
|
||||
if item.domain_uuid_key is not None
|
||||
)
|
||||
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
import socket
|
||||
from collections.abc import Mapping
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
|
||||
@@ -29,15 +29,15 @@ async def check_reachability(
|
||||
out[target_node_id].add(target_ip)
|
||||
|
||||
|
||||
async def check_reachable(topology: Topology, profiles: Mapping[NodeId, NodePerformanceProfile]) -> dict[NodeId, set[str]]:
|
||||
async def check_reachable(
|
||||
topology: Topology, profiles: Mapping[NodeId, NodePerformanceProfile]
|
||||
) -> dict[NodeId, set[str]]:
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
async with create_task_group() as tg:
|
||||
for node in topology.list_nodes():
|
||||
if not node not in profiles:
|
||||
continue
|
||||
for iface in profiles[node].network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability, iface.ip_address, node, reachable
|
||||
)
|
||||
tg.start_soon(check_reachability, iface.ip_address, node, reachable)
|
||||
|
||||
return reachable
|
||||
|
||||
@@ -255,7 +255,7 @@ class Worker:
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
@@ -266,7 +266,7 @@ class Worker:
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
@@ -407,16 +407,21 @@ class Worker:
|
||||
)
|
||||
if edge not in edges:
|
||||
logger.debug(f"ping discovered {edge=}")
|
||||
await self.event_sender.send(TopologyEdgeCreated(source=self.node_id, sink=nid, edge=edge))
|
||||
await self.event_sender.send(
|
||||
TopologyEdgeCreated(
|
||||
source=self.node_id, sink=nid, edge=edge
|
||||
)
|
||||
)
|
||||
|
||||
for nid, conn in self.state.topology.out_edges(self.node_id):
|
||||
if not isinstance(conn, SocketConnection):
|
||||
continue
|
||||
if (
|
||||
nid not in conns
|
||||
or conn.sink_multiaddr.ip_address not in conns.get(nid, set())
|
||||
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
|
||||
nid, set()
|
||||
):
|
||||
logger.debug(f"ping failed to discover {conn=}")
|
||||
await self.event_sender.send(TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn))
|
||||
await self.event_sender.send(
|
||||
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
|
||||
)
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
Reference in New Issue
Block a user