mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
222 lines
7.8 KiB
Python
222 lines
7.8 KiB
Python
from typing import Set, Mapping
|
|
from dataclasses import dataclass
|
|
from pydantic import TypeAdapter
|
|
|
|
import rustworkx as rx
|
|
|
|
from shared.types.graphs.common import (
|
|
Edge,
|
|
EdgeData,
|
|
MutableGraphProtocol,
|
|
Vertex,
|
|
VertexData,
|
|
EdgeIdT,
|
|
VertexIdT,
|
|
EdgeTypeT,
|
|
VertexTypeT,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class _VertexWrapper[VertexTypeT, VertexIdT]:
|
|
"""Internal wrapper to store vertex ID alongside vertex data."""
|
|
|
|
vertex_id: VertexIdT
|
|
vertex_data: VertexData[VertexTypeT]
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class _EdgeWrapper[EdgeTypeT, EdgeIdT]:
|
|
"""Internal wrapper to store edge ID alongside edge data."""
|
|
|
|
edge_id: EdgeIdT
|
|
edge_data: EdgeData[EdgeTypeT]
|
|
|
|
|
|
class NetworkXGraph(MutableGraphProtocol[EdgeTypeT, VertexTypeT, EdgeIdT, VertexIdT]):
    """Mutable directed graph stored in a ``rustworkx.PyDiGraph``.

    Despite the class name, the backing store is rustworkx, not networkx.
    Callers address vertices and edges by their own IDs; this class keeps
    side tables that translate those IDs into rustworkx's integer node and
    edge indices.

    The underlying graph is a multigraph, so several edges (with different
    IDs) may connect the same ordered pair of vertices. Edges are therefore
    always located by their rustworkx *edge index*, never by their
    endpoints, so parallel edges cannot be confused with one another.
    """

    edge_base: TypeAdapter[EdgeTypeT]
    vertex_base: TypeAdapter[VertexTypeT]

    _graph: rx.PyDiGraph[
        _VertexWrapper[VertexTypeT, VertexIdT], _EdgeWrapper[EdgeTypeT, EdgeIdT]
    ]
    # Vertex ID -> rustworkx node index.
    _vertex_id_to_index: dict[VertexIdT, int]
    # Edge ID -> (source node index, target node index).
    _edge_id_to_endpoints: dict[EdgeIdT, tuple[int, int]]
    # Edge ID -> rustworkx edge index (unique even among parallel edges).
    _edge_id_to_index: dict[EdgeIdT, int]

    def __init__(
        self, edge_base: TypeAdapter[EdgeTypeT], vertex_base: TypeAdapter[VertexTypeT]
    ) -> None:
        """Create an empty graph.

        Args:
            edge_base: Pydantic adapter for the edge payload type. Stored on
                the instance; not read by any method of this class
                (presumably consumed elsewhere — verify against callers).
            vertex_base: Pydantic adapter for the vertex payload type,
                stored likewise.
        """
        self.edge_base = edge_base
        self.vertex_base = vertex_base
        # Parallel edges are allowed, which is why edges are tracked by
        # rustworkx edge index rather than by (source, target) pairs.
        self._graph = rx.PyDiGraph(multigraph=True)
        self._vertex_id_to_index = {}
        self._edge_id_to_endpoints = {}
        self._edge_id_to_index = {}

    ###
    # Internal helpers
    ###

    def _incident_edge_ids(self, vertex_idx: int) -> set[EdgeIdT]:
        """Return the IDs of all edges leaving or entering node ``vertex_idx``."""
        edge_ids: set[EdgeIdT] = set()
        for _, _, edge_wrapper in self._graph.out_edges(vertex_idx):
            edge_ids.add(edge_wrapper.edge_id)
        # A self-loop may be reported by both directions; the set de-duplicates it.
        for _, _, edge_wrapper in self._graph.in_edges(vertex_idx):
            edge_ids.add(edge_wrapper.edge_id)
        return edge_ids

    ###
    # GraphProtocol methods
    ###

    def list_edges(self) -> Set[EdgeIdT]:
        """Return the IDs of every edge currently in the graph."""
        return {edge.edge_id for edge in self._graph.edges()}

    def list_vertices(self) -> Set[VertexIdT]:
        """Return the IDs of every vertex currently in the graph."""
        return {node.vertex_id for node in self._graph.nodes()}

    def get_vertices_from_edges(
        self, edges: Set[EdgeIdT]
    ) -> Mapping[EdgeIdT, Set[VertexIdT]]:
        """Map each known edge ID to the set of its endpoint vertex IDs.

        Unknown edge IDs are omitted from the result. For a self-loop the
        set contains a single vertex ID.
        """
        result: dict[EdgeIdT, Set[VertexIdT]] = {}

        for edge_id in edges:
            if edge_id in self._edge_id_to_endpoints:
                u_idx, v_idx = self._edge_id_to_endpoints[edge_id]
                u_data = self._graph.get_node_data(u_idx)
                v_data = self._graph.get_node_data(v_idx)
                result[edge_id] = {u_data.vertex_id, v_data.vertex_id}

        return result

    def get_edges_from_vertices(
        self, vertices: Set[VertexIdT]
    ) -> Mapping[VertexIdT, Set[EdgeIdT]]:
        """Map each known vertex ID to the IDs of its incoming and outgoing edges.

        Unknown vertex IDs are omitted from the result.
        """
        result: dict[VertexIdT, Set[EdgeIdT]] = {}

        for vertex_id in vertices:
            if vertex_id in self._vertex_id_to_index:
                vertex_idx = self._vertex_id_to_index[vertex_id]
                result[vertex_id] = self._incident_edge_ids(vertex_idx)

        return result

    def get_edge_data(
        self, edges: Set[EdgeIdT]
    ) -> Mapping[EdgeIdT, EdgeData[EdgeTypeT]]:
        """Map each known edge ID to its payload; unknown IDs are omitted."""
        result: dict[EdgeIdT, EdgeData[EdgeTypeT]] = {}

        for edge_id in edges:
            if edge_id in self._edge_id_to_index:
                # Look up by edge index, not by (source, target): with
                # parallel edges an endpoint lookup yields only one of them,
                # which need not be the edge with this ID.
                edge_wrapper = self._graph.get_edge_data_by_index(
                    self._edge_id_to_index[edge_id]
                )
                result[edge_id] = edge_wrapper.edge_data

        return result

    def get_vertex_data(
        self, vertices: Set[VertexIdT]
    ) -> Mapping[VertexIdT, VertexData[VertexTypeT]]:
        """Map each known vertex ID to its payload; unknown IDs are omitted."""
        result: dict[VertexIdT, VertexData[VertexTypeT]] = {}

        for vertex_id in vertices:
            if vertex_id in self._vertex_id_to_index:
                vertex_idx = self._vertex_id_to_index[vertex_id]
                vertex_wrapper = self._graph.get_node_data(vertex_idx)
                result[vertex_id] = vertex_wrapper.vertex_data

        return result

    ###
    # MutableGraphProtocol methods
    ###

    def check_edges_exists(self, edge_id: EdgeIdT) -> bool:
        """Return True if an edge with ``edge_id`` is in the graph."""
        return edge_id in self._edge_id_to_index

    def check_vertex_exists(self, vertex_id: VertexIdT) -> bool:
        """Return True if a vertex with ``vertex_id`` is in the graph."""
        return vertex_id in self._vertex_id_to_index

    def _add_edge(self, edge_id: EdgeIdT, edge_data: EdgeData[EdgeTypeT]) -> None:
        """Unsupported; always raises ``NotImplementedError``.

        This internal method is not used in favor of a safer
        ``attach_edge`` implementation.
        """
        raise NotImplementedError(
            "Use attach_edge to add edges. The internal _add_edge protocol method is flawed."
        )

    def _add_vertex(
        self, vertex_id: VertexIdT, vertex_data: VertexData[VertexTypeT]
    ) -> None:
        """Insert a vertex; does nothing if ``vertex_id`` is already present."""
        if vertex_id not in self._vertex_id_to_index:
            wrapper = _VertexWrapper(vertex_id=vertex_id, vertex_data=vertex_data)
            idx = self._graph.add_node(wrapper)
            self._vertex_id_to_index[vertex_id] = idx

    def _remove_edge(self, edge_id: EdgeIdT) -> None:
        """Remove the edge with ``edge_id``.

        Raises:
            ValueError: If no edge with ``edge_id`` exists.
        """
        if edge_id not in self._edge_id_to_index:
            raise ValueError(f"Edge with id {edge_id} not found.")

        # Remove by edge index so a parallel edge between the same pair of
        # vertices is never removed in place of this one.
        self._graph.remove_edge_from_index(self._edge_id_to_index.pop(edge_id))
        del self._edge_id_to_endpoints[edge_id]

    def _remove_vertex(self, vertex_id: VertexIdT) -> None:
        """Remove the vertex with ``vertex_id`` together with all incident edges.

        Raises:
            ValueError: If no vertex with ``vertex_id`` exists.
        """
        if vertex_id not in self._vertex_id_to_index:
            raise ValueError(f"Vertex with id {vertex_id} not found.")

        vertex_idx = self._vertex_id_to_index[vertex_id]

        # Removing the node also removes its incident edges from the graph,
        # so forget their IDs first, while they can still be enumerated.
        # Only the incident edges are visited, not every edge in the graph.
        for edge_id in self._incident_edge_ids(vertex_idx):
            self._edge_id_to_endpoints.pop(edge_id, None)
            self._edge_id_to_index.pop(edge_id, None)

        self._graph.remove_node(vertex_idx)
        del self._vertex_id_to_index[vertex_id]

    def attach_edge(
        self,
        edge: Edge[EdgeTypeT, EdgeIdT, VertexIdT],
        extra_vertex: Vertex[VertexTypeT, EdgeIdT, VertexIdT] | None = None,
    ) -> None:
        """
        Attaches an edge to the graph, overriding the default protocol implementation.

        This implementation corrects a flaw in the protocol's `_add_edge`
        signature and provides more intuitive behavior when connecting existing vertices.

        The base (source) vertex must already exist. The target vertex must
        either already exist (and then ``extra_vertex`` must be None) or be
        supplied as ``extra_vertex``, whose ID must equal the edge's target ID.
        All validation happens before the graph is modified.

        Raises:
            ValueError: If the edge ID is already in use, the base vertex is
                missing, or ``extra_vertex`` is inconsistent with whether the
                target vertex already exists.
        """
        base_vertex_id, target_vertex_id = edge.edge_vertices

        # Re-using an ID would add a second graph edge and orphan the first.
        if edge.edge_id in self._edge_id_to_index:
            raise ValueError(f"Edge with id {edge.edge_id} already exists.")

        if not self.check_vertex_exists(base_vertex_id):
            raise ValueError(f"Base vertex {base_vertex_id} does not exist.")

        target_vertex_exists = self.check_vertex_exists(target_vertex_id)

        if not target_vertex_exists:
            if extra_vertex is None:
                raise ValueError(
                    f"Target vertex {target_vertex_id} does not exist and no `extra_vertex` was provided."
                )
            if extra_vertex.vertex_id != target_vertex_id:
                raise ValueError(
                    f"The ID of `extra_vertex` ({extra_vertex.vertex_id}) does not match "
                    f"the target vertex ID of the edge ({target_vertex_id})."
                )
            self._add_vertex(extra_vertex.vertex_id, extra_vertex.vertex_data)
        elif extra_vertex is not None:
            raise ValueError(
                f"Target vertex {target_vertex_id} already exists, but `extra_vertex` was provided."
            )

        # Get the internal node indices.
        base_idx = self._vertex_id_to_index[base_vertex_id]
        target_idx = self._vertex_id_to_index[target_vertex_id]

        # Create the edge wrapper and add it; keep the returned edge index,
        # which uniquely identifies this edge even among parallel edges.
        edge_wrapper = _EdgeWrapper(edge_id=edge.edge_id, edge_data=edge.edge_data)
        edge_idx = self._graph.add_edge(base_idx, target_idx, edge_wrapper)

        # Store the mappings.
        self._edge_id_to_index[edge.edge_id] = edge_idx
        self._edge_id_to_endpoints[edge.edge_id] = (base_idx, target_idx)
|