Files
exo/shared/graphs/networkx.py
2025-07-14 21:09:08 +01:00

222 lines
7.8 KiB
Python

from typing import Set, Mapping
from dataclasses import dataclass
from pydantic import TypeAdapter
import rustworkx as rx
from shared.types.graphs.common import (
Edge,
EdgeData,
MutableGraphProtocol,
Vertex,
VertexData,
EdgeIdT,
VertexIdT,
EdgeTypeT,
VertexTypeT,
)
@dataclass(frozen=True)
class _VertexWrapper[VertexTypeT, VertexIdT]:
"""Internal wrapper to store vertex ID alongside vertex data."""
vertex_id: VertexIdT
vertex_data: VertexData[VertexTypeT]
@dataclass(frozen=True)
class _EdgeWrapper[EdgeTypeT, EdgeIdT]:
"""Internal wrapper to store edge ID alongside edge data."""
edge_id: EdgeIdT
edge_data: EdgeData[EdgeTypeT]
class NetworkXGraph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
    """Mutable directed graph backed by a ``rustworkx.PyDiGraph``.

    Despite its name, this class stores its data in a rustworkx graph, not a
    networkx one. Caller-facing vertex and edge ids are kept inside the
    node/edge payloads (``_VertexWrapper`` / ``_EdgeWrapper``) and are also
    mapped to rustworkx indices for direct lookup.

    rustworkx documents ``PyDiGraph()`` as a multigraph by default: several
    edges may share the same (source, target) pair. Edges are therefore
    tracked by their rustworkx *edge index*, never by their endpoints — an
    endpoint-based lookup or removal would act on an arbitrary one of the
    parallel edges.
    """

    # Adapters for the edge / vertex type payloads, supplied at construction.
    edge_base: TypeAdapter[EdgeTypeT]
    vertex_base: TypeAdapter[VertexTypeT]
    # Backing graph: node payloads are _VertexWrapper, edge payloads _EdgeWrapper.
    _graph: rx.PyDiGraph[
        _VertexWrapper[VertexTypeT, VertexIdT], _EdgeWrapper[EdgeTypeT, EdgeIdT]
    ]
    # Caller-facing vertex id -> rustworkx node index.
    _vertex_id_to_index: dict[VertexIdT, int]
    # Caller-facing edge id -> rustworkx edge index (indices, not endpoint
    # pairs, so parallel edges stay distinguishable).
    _edge_id_to_index: dict[EdgeIdT, int]

    def __init__(
        self, edge_base: TypeAdapter[EdgeTypeT], vertex_base: TypeAdapter[VertexTypeT]
    ) -> None:
        """Create an empty graph.

        Args:
            edge_base: Adapter for the edge type payload.
            vertex_base: Adapter for the vertex type payload.
        """
        self.edge_base = edge_base
        self.vertex_base = vertex_base
        self._graph = rx.PyDiGraph()
        self._vertex_id_to_index = {}
        self._edge_id_to_index = {}

    ###
    # GraphProtocol methods
    ###

    def list_edges(self) -> Set[EdgeIdT]:
        """Return the ids of every edge currently in the graph."""
        return {edge.edge_id for edge in self._graph.edges()}

    def list_vertices(self) -> Set[VertexIdT]:
        """Return the ids of every vertex currently in the graph."""
        return {node.vertex_id for node in self._graph.nodes()}

    def get_vertices_from_edges(
        self, edges: Set[EdgeIdT]
    ) -> Mapping[EdgeIdT, Set[VertexIdT]]:
        """Map each known edge id to the ids of its endpoints.

        Unknown edge ids are silently skipped. Endpoints are returned as a
        set, so an edge whose source and target coincide maps to one id.
        """
        result: dict[EdgeIdT, Set[VertexIdT]] = {}
        for edge_id in edges:
            if edge_id not in self._edge_id_to_index:
                continue
            edge_idx = self._edge_id_to_index[edge_id]
            u_idx, v_idx = self._graph.get_edge_endpoints_by_index(edge_idx)
            result[edge_id] = {
                self._graph.get_node_data(u_idx).vertex_id,
                self._graph.get_node_data(v_idx).vertex_id,
            }
        return result

    def get_edges_from_vertices(
        self, vertices: Set[VertexIdT]
    ) -> Mapping[VertexIdT, Set[EdgeIdT]]:
        """Map each known vertex id to the ids of all edges touching it.

        Both outgoing and incoming edges are included. Unknown vertex ids are
        silently skipped.
        """
        return {
            vertex_id: self._incident_edge_ids(self._vertex_id_to_index[vertex_id])
            for vertex_id in vertices
            if vertex_id in self._vertex_id_to_index
        }

    def get_edge_data(
        self, edges: Set[EdgeIdT]
    ) -> Mapping[EdgeIdT, EdgeData[EdgeTypeT]]:
        """Map each known edge id to its ``EdgeData``; unknown ids are skipped."""
        result: dict[EdgeIdT, EdgeData[EdgeTypeT]] = {}
        for edge_id in edges:
            if edge_id not in self._edge_id_to_index:
                continue
            # Look up by edge index: an endpoint-based lookup could return the
            # payload of a different edge sharing the same endpoints.
            edge_wrapper = self._graph.get_edge_data_by_index(
                self._edge_id_to_index[edge_id]
            )
            result[edge_id] = edge_wrapper.edge_data
        return result

    def get_vertex_data(
        self, vertices: Set[VertexIdT]
    ) -> Mapping[VertexIdT, VertexData[VertexTypeT]]:
        """Map each known vertex id to its ``VertexData``; unknown ids are skipped."""
        result: dict[VertexIdT, VertexData[VertexTypeT]] = {}
        for vertex_id in vertices:
            if vertex_id not in self._vertex_id_to_index:
                continue
            vertex_wrapper = self._graph.get_node_data(self._vertex_id_to_index[vertex_id])
            result[vertex_id] = vertex_wrapper.vertex_data
        return result

    ###
    # MutableGraphProtocol methods
    ###

    def check_edges_exists(self, edge_id: EdgeIdT) -> bool:
        """Return True if an edge with ``edge_id`` is in the graph."""
        return edge_id in self._edge_id_to_index

    def check_vertex_exists(self, vertex_id: VertexIdT) -> bool:
        """Return True if a vertex with ``vertex_id`` is in the graph."""
        return vertex_id in self._vertex_id_to_index

    def _add_edge(self, edge_id: EdgeIdT, edge_data: EdgeData[EdgeTypeT]) -> None:
        # This internal method is not used in favor of a safer `attach_edge` implementation.
        raise NotImplementedError(
            "Use attach_edge to add edges. The internal _add_edge protocol method is flawed."
        )

    def _add_vertex(
        self, vertex_id: VertexIdT, vertex_data: VertexData[VertexTypeT]
    ) -> None:
        """Add a vertex; does nothing if ``vertex_id`` is already present."""
        if vertex_id in self._vertex_id_to_index:
            return
        wrapper = _VertexWrapper(vertex_id=vertex_id, vertex_data=vertex_data)
        self._vertex_id_to_index[vertex_id] = self._graph.add_node(wrapper)

    def _remove_edge(self, edge_id: EdgeIdT) -> None:
        """Remove the edge identified by ``edge_id``.

        Raises:
            ValueError: if no edge with ``edge_id`` exists.
        """
        if edge_id not in self._edge_id_to_index:
            raise ValueError(f"Edge with id {edge_id} not found.")
        # Remove by edge index so that exactly this edge goes, even when other
        # edges share its endpoints.
        self._graph.remove_edge_from_index(self._edge_id_to_index[edge_id])
        del self._edge_id_to_index[edge_id]

    def _remove_vertex(self, vertex_id: VertexIdT) -> None:
        """Remove a vertex together with every edge touching it.

        Raises:
            ValueError: if no vertex with ``vertex_id`` exists.
        """
        if vertex_id not in self._vertex_id_to_index:
            raise ValueError(f"Vertex with id {vertex_id} not found.")
        vertex_idx = self._vertex_id_to_index[vertex_id]
        # Forget the ids of incident edges before the node (and with it those
        # edges) is dropped from the graph; only this vertex's edges are
        # visited instead of scanning the whole edge map.
        for edge_id in self._incident_edge_ids(vertex_idx):
            self._edge_id_to_index.pop(edge_id, None)
        self._graph.remove_node(vertex_idx)
        del self._vertex_id_to_index[vertex_id]

    def attach_edge(
        self,
        edge: Edge[EdgeTypeT, EdgeIdT, VertexIdT],
        extra_vertex: Vertex[VertexTypeT, EdgeIdT, VertexIdT] | None = None,
    ) -> None:
        """
        Attaches an edge to the graph, overriding the default protocol implementation.
        This implementation corrects a flaw in the protocol's `_add_edge`
        signature and provides more intuitive behavior when connecting existing vertices.

        The edge runs from ``edge.edge_vertices[0]`` (base, must exist) to
        ``edge.edge_vertices[1]`` (target). If the target does not exist it is
        created from ``extra_vertex``; if it does exist, ``extra_vertex`` must
        be None. All checks run before the graph is modified.

        Raises:
            ValueError: if ``edge.edge_id`` is already in use, the base vertex
                is missing, the target is missing and no (or a mismatched)
                ``extra_vertex`` is given, or ``extra_vertex`` is given for an
                existing target.
        """
        # Reusing an id would add a second graph edge and orphan the first
        # one in the id map, so reject it before mutating anything.
        if self.check_edges_exists(edge.edge_id):
            raise ValueError(f"Edge with id {edge.edge_id} already exists.")
        base_vertex_id, target_vertex_id = edge.edge_vertices
        if not self.check_vertex_exists(base_vertex_id):
            raise ValueError(f"Base vertex {base_vertex_id} does not exist.")
        if not self.check_vertex_exists(target_vertex_id):
            if extra_vertex is None:
                raise ValueError(
                    f"Target vertex {target_vertex_id} does not exist and no `extra_vertex` was provided."
                )
            if extra_vertex.vertex_id != target_vertex_id:
                raise ValueError(
                    f"The ID of `extra_vertex` ({extra_vertex.vertex_id}) does not match "
                    f"the target vertex ID of the edge ({target_vertex_id})."
                )
            self._add_vertex(extra_vertex.vertex_id, extra_vertex.vertex_data)
        elif extra_vertex is not None:
            raise ValueError(
                f"Target vertex {target_vertex_id} already exists, but `extra_vertex` was provided."
            )
        base_idx = self._vertex_id_to_index[base_vertex_id]
        target_idx = self._vertex_id_to_index[target_vertex_id]
        edge_wrapper = _EdgeWrapper(edge_id=edge.edge_id, edge_data=edge.edge_data)
        # Keep the edge index returned by rustworkx: it uniquely identifies
        # this edge even when parallel edges share its endpoints.
        self._edge_id_to_index[edge.edge_id] = self._graph.add_edge(
            base_idx, target_idx, edge_wrapper
        )

    ###
    # Internal helpers
    ###

    def _incident_edge_ids(self, vertex_idx: int) -> Set[EdgeIdT]:
        """Return the ids of all outgoing and incoming edges of a node index.

        An edge reported by both the outgoing and incoming queries is kept
        once, since the result is a set.
        """
        edge_ids: Set[EdgeIdT] = {
            edge_data.edge_id for _, _, edge_data in self._graph.out_edges(vertex_idx)
        }
        edge_ids.update(
            edge_data.edge_id for _, _, edge_data in self._graph.in_edges(vertex_idx)
        )
        return edge_ids