mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
@@ -15,10 +15,9 @@ from sqlmodel import SQLModel
|
||||
from shared.types.events.common import (
|
||||
BaseEvent,
|
||||
EventCategories,
|
||||
EventFromEventLog,
|
||||
NodeId,
|
||||
)
|
||||
from shared.types.events.registry import EventParser
|
||||
from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser
|
||||
|
||||
from .types import StoredEvent
|
||||
|
||||
@@ -87,7 +86,7 @@ class AsyncSQLiteEventStorage:
|
||||
async def get_events_since(
|
||||
self,
|
||||
last_idx: int
|
||||
) -> Sequence[EventFromEventLog[EventCategories]]:
|
||||
) -> Sequence[EventFromEventLogTyped]:
|
||||
"""Retrieve events after a specific index."""
|
||||
if self._closed:
|
||||
raise RuntimeError("Storage is closed")
|
||||
@@ -102,7 +101,7 @@ class AsyncSQLiteEventStorage:
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
events: list[EventFromEventLog[EventCategories]] = []
|
||||
events: list[EventFromEventLogTyped] = []
|
||||
for row in rows:
|
||||
rowid: int = cast(int, row[0])
|
||||
origin: str = cast(str, row[1])
|
||||
@@ -113,7 +112,7 @@ class AsyncSQLiteEventStorage:
|
||||
else:
|
||||
event_data = cast(dict[str, Any], raw_event_data)
|
||||
event = await self._deserialize_event(event_data)
|
||||
events.append(EventFromEventLog(
|
||||
events.append(EventFromEventLogTyped(
|
||||
event=event,
|
||||
origin=NodeId(uuid=UUID(origin)),
|
||||
idx_in_log=rowid # rowid becomes idx_in_log
|
||||
@@ -215,13 +214,13 @@ class AsyncSQLiteEventStorage:
|
||||
|
||||
try:
|
||||
async with AsyncSession(self._engine) as session:
|
||||
for event, origin in batch:
|
||||
for event, origin in batch:
|
||||
stored_event = StoredEvent(
|
||||
origin=str(origin.uuid),
|
||||
event_type=event.event_type.value,
|
||||
event_category=next(iter(event.event_category)).value,
|
||||
event_type=str(event.event_type),
|
||||
event_category=str(next(iter(event.event_category))),
|
||||
event_id=str(event.event_id),
|
||||
event_data=event.model_dump() # SQLModel handles JSON serialization automatically
|
||||
event_data=event.model_dump(mode='json') # mode='json' ensures UUID conversion
|
||||
)
|
||||
session.add(stored_event)
|
||||
|
||||
@@ -233,9 +232,13 @@ class AsyncSQLiteEventStorage:
|
||||
self._logger.error(f"Failed to commit batch: {e}")
|
||||
raise
|
||||
|
||||
# TODO: This is a hack to get the event deserialization working. We need to find a better way to do this.
async def _deserialize_event(self, event_data: dict[str, Any]) -> Event:
    """Deserialize event data back to a typed Event.

    Args:
        event_data: Mapping holding the stored event payload. EventParser
            expects it to contain the discriminator field so it can select
            the matching event class.

    Returns:
        The object produced by ``EventParser.validate_python``.
    """
    # EventParser is annotated as returning BaseEvent, but the value is known
    # to be a specific Event type, so the checker warning is suppressed here.
    return EventParser.validate_python(event_data)  # type: ignore[reportReturnType]
|
||||
|
||||
async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return raw event data for testing purposes."""
|
||||
|
||||
@@ -11,6 +11,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
|
||||
from shared.types.common import NodeId
|
||||
from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
|
||||
from shared.types.events.events import (
|
||||
ChunkGenerated,
|
||||
EventCategoryEnum,
|
||||
StreamingEventTypes,
|
||||
)
|
||||
from shared.types.tasks.common import TaskId
|
||||
|
||||
# Type ignore comment for all protected member access in this test file
|
||||
# pyright: reportPrivateUsage=false
|
||||
@@ -393,4 +400,69 @@ class TestAsyncSQLiteEventStorage:
|
||||
for i, row in enumerate(rows):
|
||||
assert row[0] == i + 1 # rowid should be sequential
|
||||
|
||||
await storage.close()
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunk_generated_event_serialization(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
    """Test that ChunkGenerated event with nested types can be serialized and deserialized correctly.

    A ChunkGenerated event wrapping a TokenChunk is appended through the
    storage API, read back with ``get_events_since(0)``, and every level of
    the nested structure is compared against the values that were stored.
    """
    default_config = EventLogConfig()
    storage = AsyncSQLiteEventStorage(
        db_path=temp_db_path,
        batch_size=default_config.batch_size,
        batch_timeout_ms=default_config.batch_timeout_ms,
        debounce_ms=default_config.debounce_ms,
        max_age_ms=default_config.max_age_ms,
    )
    await storage.start()

    # Close the storage even when an assertion below fails, so a failing run
    # does not leave a started storage behind.
    try:
        # Create a ChunkGenerated event with nested TokenChunk
        task_id = TaskId(uuid=uuid4())
        chunk_data = TokenChunkData(
            text="Hello, world!",
            token_id=42,
            finish_reason="stop",
        )
        token_chunk = TokenChunk(
            chunk_data=chunk_data,
            chunk_type=ChunkType.token,
            task_id=task_id,
            idx=0,
            model="test-model",
        )

        chunk_generated_event = ChunkGenerated(
            event_type=StreamingEventTypes.ChunkGenerated,
            event_category=EventCategoryEnum.MutatesTaskState,
            task_id=task_id,
            chunk=token_chunk,
        )

        # Store the event using the storage API
        await storage.append_events([chunk_generated_event], sample_node_id)  # type: ignore[reportArgumentType]

        # Wait for batch to be written
        await asyncio.sleep(0.5)

        # Retrieve the event
        events = await storage.get_events_since(0)

        # Verify we got the event back
        assert len(events) == 1
        retrieved_event_wrapper = events[0]
        assert retrieved_event_wrapper.origin == sample_node_id

        # Verify the event was deserialized correctly
        retrieved_event = retrieved_event_wrapper.event
        assert isinstance(retrieved_event, ChunkGenerated)
        assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated
        assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState
        assert retrieved_event.task_id == task_id

        # Verify the nested chunk was deserialized correctly
        retrieved_chunk = retrieved_event.chunk
        assert isinstance(retrieved_chunk, TokenChunk)
        assert retrieved_chunk.chunk_type == ChunkType.token
        assert retrieved_chunk.task_id == task_id
        assert retrieved_chunk.idx == 0
        assert retrieved_chunk.model == "test-model"

        # Verify the chunk data
        assert retrieved_chunk.chunk_data.text == "Hello, world!"
        assert retrieved_chunk.chunk_data.token_id == 42
        assert retrieved_chunk.chunk_data.finish_reason == "stop"
    finally:
        await storage.close()
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum, StrEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
FrozenSet,
|
||||
@@ -10,6 +11,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from shared.types.common import NewUUID, NodeId
|
||||
@@ -138,7 +142,13 @@ class BaseEvent[
|
||||
event_category: SetMembersT
|
||||
event_id: EventId = EventId()
|
||||
|
||||
def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
    """Check if the event was sent by the correct node.

    This is a placeholder implementation that always returns True.
    Subclasses can override this method to implement specific validation logic.

    Args:
        origin_id: Node the event is attributed to. Not inspected by this
            placeholder.

    Returns:
        Always ``True``.
    """
    return True
|
||||
|
||||
|
||||
class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from types import UnionType
|
||||
from typing import Annotated, Any, Mapping, Type, get_args
|
||||
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from shared.constants import get_error_reporting_message
|
||||
from shared.types.events.common import (
|
||||
@@ -9,6 +9,7 @@ from shared.types.events.common import (
|
||||
EventCategories,
|
||||
EventTypes,
|
||||
InstanceEventTypes,
|
||||
NodeId,
|
||||
NodePerformanceEventTypes,
|
||||
RunnerStatusEventTypes,
|
||||
StreamingEventTypes,
|
||||
@@ -127,3 +128,16 @@ check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event)
|
||||
|
||||
_EventType = Annotated[Event, Field(discriminator="event_type")]
|
||||
EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType)
|
||||
|
||||
|
||||
# Define a properly typed EventFromEventLog that preserves specific event types
|
||||
|
||||
class EventFromEventLogTyped(BaseModel):
    """Event-log entry whose ``event`` field keeps its specific event type."""

    # The wrapped event, typed with the module's _EventType alias.
    event: _EventType
    # Node recorded as the origin of the event.
    origin: NodeId
    # Position in the log; Field(gt=0) rejects values that are not positive.
    idx_in_log: int = Field(gt=0)

    def check_event_was_sent_by_correct_node(self) -> bool:
        """Delegate to the wrapped event's own check, passing this entry's origin."""
        wrapped_event = self.event
        return wrapped_event.check_event_was_sent_by_correct_node(self.origin)
|
||||
|
||||
Reference in New Issue
Block a user