fix pydantic validation

This commit is contained in:
Evan
2025-12-19 17:47:57 +00:00
parent 9451ced365
commit f7a2208694
6 changed files with 27 additions and 24 deletions

View File

@@ -8,7 +8,7 @@ from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import SocketConnection, RDMAConnection
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection
from exo.shared.types.topology import SocketConnection
def test_state_serialization_roundtrip() -> None:
@@ -11,14 +11,12 @@ def test_state_serialization_roundtrip() -> None:
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = Connection(
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
connection = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
state.topology.add_connection(connection)
state.topology.add_connection(node_a, node_b, connection)
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)

View File

@@ -13,7 +13,7 @@ class TopologySnapshot(BaseModel):
nodes: list[NodeId]
connections: list[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
model_config = ConfigDict(frozen=True, extra="forbid")
class Topology:

View File

@@ -1,5 +1,5 @@
import anyio
from pydantic import BaseModel
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
@@ -28,7 +28,7 @@ class TBConnectivityItem(BaseModel, extra="ignore"):
class TBConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None
device_name_key: str
_items: list[TBConnectivityItem] | None
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: TBReceptacleTag
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
@@ -39,12 +39,12 @@ class TBConnectivityData(BaseModel, extra="ignore"):
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
def conn(self) -> TBConnection | None:
    """Build the connection from this entry to its first identified peer.

    Returns:
        A ``TBConnection`` whose source is this entry's ``domain_uuid_key``
        and whose sink is the first item in ``items`` that has a non-None
        ``domain_uuid_key``. Returns ``None`` when this entry has no
        ``domain_uuid_key``, when ``items`` is ``None``, or when no item
        carries a ``domain_uuid_key``.
    """
    if self.domain_uuid_key is None or self.items is None:
        return None
    # Give next() a default: without one, a list with no identified peer
    # would raise StopIteration out of a method declared to return None.
    sink_key = next(
        (
            item.domain_uuid_key
            for item in self.items
            if item.domain_uuid_key is not None
        ),
        None,
    )
    if sink_key is None:
        return None
    return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)

View File

@@ -1,5 +1,5 @@
from collections.abc import Mapping
import socket
from collections.abc import Mapping
from anyio import create_task_group, to_thread
@@ -29,15 +29,15 @@ async def check_reachability(
out[target_node_id].add(target_ip)
async def check_reachable(
    topology: Topology, profiles: Mapping[NodeId, NodePerformanceProfile]
) -> dict[NodeId, set[str]]:
    """Probe every known interface of every profiled node in the topology.

    One ``check_reachability`` task is started per network interface, all
    inside a single task group, and each task is handed the shared
    ``reachable`` dict to record into.

    Args:
        topology: Source of the node ids to probe (via ``list_nodes()``).
        profiles: Per-node performance profiles; only nodes present here
            have interfaces to probe.

    Returns:
        The ``reachable`` dict after all probe tasks have finished, keyed
        by node id.
    """
    reachable: dict[NodeId, set[str]] = {}
    async with create_task_group() as tg:
        for node in topology.list_nodes():
            # Skip nodes we have no profile for. The previous guard,
            # `not node not in profiles`, parsed as `node in profiles` and so
            # skipped exactly the nodes that could be probed, then hit a
            # KeyError on the rest.
            if node not in profiles:
                continue
            for iface in profiles[node].network_interfaces:
                tg.start_soon(check_reachability, iface.ip_address, node, reachable)
    return reachable

View File

@@ -255,7 +255,7 @@ class Worker:
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
case ConnectionMessageType.Disconnected:
@@ -266,7 +266,7 @@ class Worker:
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
)
),
)
async def _nack_request(self, since_idx: int) -> None:
@@ -407,16 +407,21 @@ class Worker:
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(TopologyEdgeCreated(source=self.node_id, sink=nid, edge=edge))
await self.event_sender.send(
TopologyEdgeCreated(
source=self.node_id, sink=nid, edge=edge
)
)
for nid, conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn, SocketConnection):
continue
if (
nid not in conns
or conn.sink_multiaddr.ip_address not in conns.get(nid, set())
if nid not in conns or conn.sink_multiaddr.ip_address not in conns.get(
nid, set()
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn))
await self.event_sender.send(
TopologyEdgeDeleted(source=self.node_id, sink=nid, edge=conn)
)
await anyio.sleep(10)