mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
30
shared/tests/test_state_serialization.py
Normal file
30
shared/tests/test_state_serialization.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.state import State
|
||||
from shared.types.topology import Connection
|
||||
|
||||
|
||||
def test_state_serialization_roundtrip() -> None:
|
||||
"""Verify that State → JSON → State round-trip preserves topology."""
|
||||
|
||||
# --- build a simple state ------------------------------------------------
|
||||
node_a = NodeId("node-a")
|
||||
node_b = NodeId("node-b")
|
||||
|
||||
connection = Connection(
|
||||
source_node_id=node_a,
|
||||
sink_node_id=node_b,
|
||||
source_multiaddr="/ip4/127.0.0.1/tcp/10000",
|
||||
sink_multiaddr="/ip4/127.0.0.1/tcp/10001",
|
||||
)
|
||||
|
||||
state = State()
|
||||
state.topology.add_connection(connection)
|
||||
state.topology.master_node_id = node_a
|
||||
|
||||
json_repr = state.model_dump_json()
|
||||
restored_state = State.model_validate_json(json_repr)
|
||||
|
||||
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
|
||||
assert restored_state.model_dump_json() == json_repr
|
||||
@@ -1,12 +1,24 @@
|
||||
import contextlib
|
||||
from typing import Iterable
|
||||
|
||||
import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from shared.types.topology import Connection, Node, TopologyProto
|
||||
|
||||
|
||||
class TopologySnapshot(BaseModel):
|
||||
"""Immutable serialisable representation of a :class:`Topology`."""
|
||||
|
||||
nodes: list[Node]
|
||||
connections: list[Connection]
|
||||
master_node_id: NodeId | None = None
|
||||
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
|
||||
|
||||
class Topology(TopologyProto):
|
||||
def __init__(self) -> None:
|
||||
self._graph: rx.PyDiGraph[Node, Connection] = rx.PyDiGraph()
|
||||
@@ -14,8 +26,35 @@ class Topology(TopologyProto):
|
||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
|
||||
self.master_node_id: NodeId | None = None
|
||||
|
||||
# TODO: implement serialization + deserialization method
|
||||
|
||||
def to_snapshot(self) -> TopologySnapshot:
|
||||
"""Return an immutable snapshot suitable for JSON serialisation."""
|
||||
|
||||
return TopologySnapshot(
|
||||
nodes=list(self.list_nodes()),
|
||||
connections=list(self.list_connections()),
|
||||
master_node_id=self.master_node_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
|
||||
"""Reconstruct a :class:`Topology` from *snapshot*.
|
||||
|
||||
The reconstructed topology is equivalent (w.r.t. nodes, connections
|
||||
and ``master_node_id``) to the original one that produced *snapshot*.
|
||||
"""
|
||||
|
||||
topology = cls()
|
||||
topology.master_node_id = snapshot.master_node_id
|
||||
|
||||
for node in snapshot.nodes:
|
||||
with contextlib.suppress(ValueError):
|
||||
topology.add_node(node, node.node_id)
|
||||
|
||||
for connection in snapshot.connections:
|
||||
topology.add_connection(connection)
|
||||
|
||||
return topology
|
||||
|
||||
def add_node(self, node: Node, node_id: NodeId) -> None:
|
||||
if node_id in self._node_id_to_rx_id_map:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from shared.topology import Topology
|
||||
from shared.types.common import NodeId
|
||||
@@ -11,8 +12,25 @@ from shared.types.worker.instances import Instance
|
||||
from shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
|
||||
|
||||
def _encode_topology(topo: "Topology") -> dict[str, Any]: # noqa: D401
|
||||
"""Serialise *topo* into a JSON-compatible dict."""
|
||||
|
||||
return topo.to_snapshot().model_dump()
|
||||
|
||||
class State(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""Global system state.
|
||||
|
||||
The :class:`Topology` instance is encoded/decoded via an immutable
|
||||
:class:`~shared.topology.TopologySnapshot` to ensure compatibility with
|
||||
standard JSON serialisation.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
json_encoders={
|
||||
Topology: _encode_topology,
|
||||
},
|
||||
)
|
||||
node_status: Mapping[NodeId, NodeStatus] = {}
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
runners: Mapping[RunnerId, RunnerStatus] = {}
|
||||
@@ -21,3 +39,25 @@ class State(BaseModel):
|
||||
topology: Topology = Topology()
|
||||
history: Sequence[Topology] = []
|
||||
last_event_applied_idx: int = Field(default=0, ge=0)
|
||||
|
||||
@field_validator("topology", mode="before")
|
||||
@classmethod
|
||||
def _deserialize_topology(cls, value: object) -> Topology: # noqa: D401 – Pydantic validator signature
|
||||
"""Convert an incoming *value* into a :class:`Topology` instance.
|
||||
|
||||
Accepts either an already constructed :class:`Topology` or a mapping
|
||||
representing :class:`~shared.topology.TopologySnapshot`.
|
||||
"""
|
||||
|
||||
if isinstance(value, Topology):
|
||||
return value
|
||||
|
||||
# Lazy import to avoid circular dependencies.
|
||||
from shared.topology import Topology as _Topology
|
||||
from shared.topology import TopologySnapshot
|
||||
|
||||
if isinstance(value, Mapping): # likely a snapshot-dict coming from JSON
|
||||
snapshot = TopologySnapshot(**cast(dict[str, Any], value)) # type: ignore[arg-type]
|
||||
return _Topology.from_snapshot(snapshot)
|
||||
|
||||
raise TypeError("Invalid representation for Topology field in State")
|
||||
|
||||
Reference in New Issue
Block a user