From 261e575262a03645b10ed0a7fe429fb78f7fefd7 Mon Sep 17 00:00:00 2001 From: Gelu Vrabie Date: Fri, 25 Jul 2025 15:09:03 +0100 Subject: [PATCH] Serialize topology Co-authored-by: Gelu Vrabie --- shared/tests/test_state_serialization.py | 30 ++++++++++++++++ shared/topology.py | 43 +++++++++++++++++++++-- shared/types/state.py | 44 ++++++++++++++++++++++-- 3 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 shared/tests/test_state_serialization.py diff --git a/shared/tests/test_state_serialization.py b/shared/tests/test_state_serialization.py new file mode 100644 index 00000000..11306b34 --- /dev/null +++ b/shared/tests/test_state_serialization.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from shared.types.common import NodeId +from shared.types.state import State +from shared.types.topology import Connection + + +def test_state_serialization_roundtrip() -> None: + """Verify that State → JSON → State round-trip preserves topology.""" + + # --- build a simple state ------------------------------------------------ + node_a = NodeId("node-a") + node_b = NodeId("node-b") + + connection = Connection( + source_node_id=node_a, + sink_node_id=node_b, + source_multiaddr="/ip4/127.0.0.1/tcp/10000", + sink_multiaddr="/ip4/127.0.0.1/tcp/10001", + ) + + state = State() + state.topology.add_connection(connection) + state.topology.master_node_id = node_a + + json_repr = state.model_dump_json() + restored_state = State.model_validate_json(json_repr) + + assert state.topology.to_snapshot() == restored_state.topology.to_snapshot() + assert restored_state.model_dump_json() == json_repr \ No newline at end of file diff --git a/shared/topology.py b/shared/topology.py index c44c717e..0e40905d 100644 --- a/shared/topology.py +++ b/shared/topology.py @@ -1,12 +1,24 @@ +import contextlib from typing import Iterable import rustworkx as rx +from pydantic import BaseModel, ConfigDict from shared.types.common import NodeId from shared.types.profiling import ConnectionProfile, NodePerformanceProfile from shared.types.topology import Connection, Node, TopologyProto +class TopologySnapshot(BaseModel): + """Immutable serialisable representation of a :class:`Topology`.""" + + nodes: list[Node] + connections: list[Connection] + master_node_id: NodeId | None = None + + model_config = ConfigDict(frozen=True, extra="forbid", strict=True) + + class Topology(TopologyProto): def __init__(self) -> None: self._graph: rx.PyDiGraph[Node, Connection] = rx.PyDiGraph() @@ -14,8 +26,35 @@ class Topology(TopologyProto): self._rx_id_to_node_id_map: dict[int, NodeId] = dict() self._edge_id_to_rx_id_map: dict[Connection, int] = dict() self.master_node_id: NodeId | None = None - - # TODO: implement serialization + deserialization method + + def to_snapshot(self) -> TopologySnapshot: + """Return an immutable snapshot suitable for JSON serialisation.""" + + return TopologySnapshot( + nodes=list(self.list_nodes()), + connections=list(self.list_connections()), + master_node_id=self.master_node_id, + ) + + @classmethod + def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology": + """Reconstruct a :class:`Topology` from *snapshot*. + + The reconstructed topology is equivalent (w.r.t. nodes, connections + and ``master_node_id``) to the original one that produced *snapshot*. + """ + + topology = cls() + topology.master_node_id = snapshot.master_node_id + + for node in snapshot.nodes: + with contextlib.suppress(ValueError): + topology.add_node(node, node.node_id) + + for connection in snapshot.connections: + topology.add_connection(connection) + + return topology def add_node(self, node: Node, node_id: NodeId) -> None: if node_id in self._node_id_to_rx_id_map: diff --git a/shared/types/state.py b/shared/types/state.py index 7736b838..24a0c424 100644 --- a/shared/types/state.py +++ b/shared/types/state.py @@ -1,6 +1,7 @@ from collections.abc import Mapping, Sequence +from typing import Any, cast -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from shared.topology import Topology from shared.types.common import NodeId @@ -11,8 +12,25 @@ from shared.types.worker.instances import Instance from shared.types.worker.runners import RunnerId, RunnerStatus +def _encode_topology(topo: "Topology") -> dict[str, Any]: # noqa: D401 + """Serialise *topo* into a JSON-compatible dict.""" + + return topo.to_snapshot().model_dump() + class State(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) + """Global system state. + + The :class:`Topology` instance is encoded/decoded via an immutable + :class:`~shared.topology.TopologySnapshot` to ensure compatibility with + standard JSON serialisation. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + json_encoders={ + Topology: _encode_topology, + }, + ) node_status: Mapping[NodeId, NodeStatus] = {} instances: Mapping[InstanceId, Instance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} @@ -21,3 +39,25 @@ class State(BaseModel): topology: Topology = Topology() history: Sequence[Topology] = [] last_event_applied_idx: int = Field(default=0, ge=0) + + @field_validator("topology", mode="before") + @classmethod + def _deserialize_topology(cls, value: object) -> Topology: # noqa: D401 – Pydantic validator signature + """Convert an incoming *value* into a :class:`Topology` instance. + + Accepts either an already constructed :class:`Topology` or a mapping + representing :class:`~shared.topology.TopologySnapshot`. + """ + + if isinstance(value, Topology): + return value + + # Lazy import to avoid circular dependencies. + from shared.topology import Topology as _Topology + from shared.topology import TopologySnapshot + + if isinstance(value, Mapping): # likely a snapshot-dict coming from JSON + snapshot = TopologySnapshot(**cast(dict[str, Any], value)) # type: ignore[arg-type] + return _Topology.from_snapshot(snapshot) + + raise TypeError("Invalid representation for Topology field in State")