Serialize topology

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
This commit is contained in:
Gelu Vrabie
2025-07-25 15:09:03 +01:00
committed by GitHub
parent a97fb27c64
commit 261e575262
3 changed files with 113 additions and 4 deletions

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from shared.types.common import NodeId
from shared.types.state import State
from shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
"""Verify that State → JSON → State round-trip preserves topology."""
# --- build a simple state ------------------------------------------------
node_a = NodeId("node-a")
node_b = NodeId("node-b")
connection = Connection(
source_node_id=node_a,
sink_node_id=node_b,
source_multiaddr="/ip4/127.0.0.1/tcp/10000",
sink_multiaddr="/ip4/127.0.0.1/tcp/10001",
)
state = State()
state.topology.add_connection(connection)
state.topology.master_node_id = node_a
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,12 +1,24 @@
import contextlib
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from shared.types.common import NodeId
from shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from shared.types.topology import Connection, Node, TopologyProto
class TopologySnapshot(BaseModel):
"""Immutable serialisable representation of a :class:`Topology`."""
nodes: list[Node]
connections: list[Connection]
master_node_id: NodeId | None = None
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
class Topology(TopologyProto):
def __init__(self) -> None:
self._graph: rx.PyDiGraph[Node, Connection] = rx.PyDiGraph()
@@ -14,8 +26,35 @@ class Topology(TopologyProto):
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
self.master_node_id: NodeId | None = None
# TODO: implement serialization + deserialization method
def to_snapshot(self) -> TopologySnapshot:
"""Return an immutable snapshot suitable for JSON serialisation."""
return TopologySnapshot(
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
master_node_id=self.master_node_id,
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
"""Reconstruct a :class:`Topology` from *snapshot*.
The reconstructed topology is equivalent (w.r.t. nodes, connections
and ``master_node_id``) to the original one that produced *snapshot*.
"""
topology = cls()
topology.master_node_id = snapshot.master_node_id
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node, node.node_id)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node: Node, node_id: NodeId) -> None:
if node_id in self._node_id_to_rx_id_map:

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from shared.topology import Topology
from shared.types.common import NodeId
@@ -11,8 +12,25 @@ from shared.types.worker.instances import Instance
from shared.types.worker.runners import RunnerId, RunnerStatus
def _encode_topology(topo: "Topology") -> dict[str, Any]: # noqa: D401
"""Serialise *topo* into a JSON-compatible dict."""
return topo.to_snapshot().model_dump()
class State(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Global system state.
The :class:`Topology` instance is encoded/decoded via an immutable
:class:`~shared.topology.TopologySnapshot` to ensure compatibility with
standard JSON serialisation.
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={
Topology: _encode_topology,
},
)
node_status: Mapping[NodeId, NodeStatus] = {}
instances: Mapping[InstanceId, Instance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
@@ -21,3 +39,25 @@ class State(BaseModel):
topology: Topology = Topology()
history: Sequence[Topology] = []
last_event_applied_idx: int = Field(default=0, ge=0)
@field_validator("topology", mode="before")
@classmethod
def _deserialize_topology(cls, value: object) -> Topology: # noqa: D401 Pydantic validator signature
"""Convert an incoming *value* into a :class:`Topology` instance.
Accepts either an already constructed :class:`Topology` or a mapping
representing :class:`~shared.topology.TopologySnapshot`.
"""
if isinstance(value, Topology):
return value
# Lazy import to avoid circular dependencies.
from shared.topology import Topology as _Topology
from shared.topology import TopologySnapshot
if isinstance(value, Mapping): # likely a snapshot-dict coming from JSON
snapshot = TopologySnapshot(**cast(dict[str, Any], value)) # type: ignore[arg-type]
return _Topology.from_snapshot(snapshot)
raise TypeError("Invalid representation for Topology field in State")