From 108128b620eff8eb72a9cf95d83ff9ef40067f43 Mon Sep 17 00:00:00 2001
From: Gelu Vrabie
Date: Mon, 21 Jul 2025 22:43:09 +0100
Subject: [PATCH] fix sqlite connector

Co-authored-by: Gelu Vrabie
---

Notes:
    Line breaks and indentation in this copy were restored from a
    whitespace-collapsed version of the patch. Hunk line counts match the
    @@ headers and the diffstat, but context-line whitespace (indentation,
    blank-line placement, trailing spaces) is inferred -- TODO confirm
    against the original commit, or apply with
    `git apply --ignore-whitespace`.

 shared/db/sqlite/connector.py         | 25 ++++++----
 shared/tests/test_sqlite_connector.py | 72 +++++++++++++++++++++++++++
 shared/types/events/common.py         | 12 ++++-
 shared/types/events/registry.py       | 16 +++++-
 4 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/shared/db/sqlite/connector.py b/shared/db/sqlite/connector.py
index 199d2973..b0abff65 100644
--- a/shared/db/sqlite/connector.py
+++ b/shared/db/sqlite/connector.py
@@ -15,10 +15,9 @@ from sqlmodel import SQLModel
 from shared.types.events.common import (
     BaseEvent,
     EventCategories,
-    EventFromEventLog,
     NodeId,
 )
-from shared.types.events.registry import EventParser
+from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser
 
 from .types import StoredEvent
 
@@ -87,7 +86,7 @@ class AsyncSQLiteEventStorage:
     async def get_events_since(
         self,
         last_idx: int
-    ) -> Sequence[EventFromEventLog[EventCategories]]:
+    ) -> Sequence[EventFromEventLogTyped]:
         """Retrieve events after a specific index."""
         if self._closed:
             raise RuntimeError("Storage is closed")
@@ -102,7 +101,7 @@ class AsyncSQLiteEventStorage:
             )
             rows = result.fetchall()
 
-            events: list[EventFromEventLog[EventCategories]] = []
+            events: list[EventFromEventLogTyped] = []
             for row in rows:
                 rowid: int = cast(int, row[0])
                 origin: str = cast(str, row[1])
@@ -113,7 +112,7 @@ class AsyncSQLiteEventStorage:
                 else:
                     event_data = cast(dict[str, Any], raw_event_data)
                 event = await self._deserialize_event(event_data)
-                events.append(EventFromEventLog(
+                events.append(EventFromEventLogTyped(
                     event=event,
                     origin=NodeId(uuid=UUID(origin)),
                     idx_in_log=rowid  # rowid becomes idx_in_log
@@ -215,13 +214,13 @@ class AsyncSQLiteEventStorage:
 
         try:
             async with AsyncSession(self._engine) as session:
-                for event, origin in batch: 
+                for event, origin in batch:
                     stored_event = StoredEvent(
                         origin=str(origin.uuid),
-                        event_type=event.event_type.value,
-                        event_category=next(iter(event.event_category)).value,
+                        event_type=str(event.event_type),
+                        event_category=str(next(iter(event.event_category))),
                         event_id=str(event.event_id),
-                        event_data=event.model_dump()  # SQLModel handles JSON serialization automatically
+                        event_data=event.model_dump(mode='json')  # mode='json' ensures UUID conversion
                     )
                     session.add(stored_event)
 
@@ -233,9 +232,13 @@ class AsyncSQLiteEventStorage:
             self._logger.error(f"Failed to commit batch: {e}")
             raise
 
-    async def _deserialize_event(self, event_data: dict[str, Any]) -> BaseEvent[EventCategories]:
+    # TODO: This is a hack to get the event deserialization working. We need to find a better way to do this.
+    async def _deserialize_event(self, event_data: dict[str, Any]) -> Event:
         """Deserialize event data back to typed Event."""
-        return EventParser.validate_python(event_data)
+        # EventParser expects the discriminator field for proper deserialization
+        result = EventParser.validate_python(event_data)
+        # EventParser returns BaseEvent but we know it's actually a specific Event type
+        return result  # type: ignore[reportReturnType]
 
     async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
         """Return raw event data for testing purposes."""
diff --git a/shared/tests/test_sqlite_connector.py b/shared/tests/test_sqlite_connector.py
index 80e921ac..6d3ec13f 100644
--- a/shared/tests/test_sqlite_connector.py
+++ b/shared/tests/test_sqlite_connector.py
@@ -11,6 +11,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 
 from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
 from shared.types.common import NodeId
+from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
+from shared.types.events.events import (
+    ChunkGenerated,
+    EventCategoryEnum,
+    StreamingEventTypes,
+)
+from shared.types.tasks.common import TaskId
 
 # Type ignore comment for all protected member access in this test file
 # pyright: reportPrivateUsage=false
@@ -393,4 +400,69 @@ class TestAsyncSQLiteEventStorage:
         for i, row in enumerate(rows):
             assert row[0] == i + 1  # rowid should be sequential
 
+        await storage.close()
+
+    @pytest.mark.asyncio
+    async def test_chunk_generated_event_serialization(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
+        """Test that ChunkGenerated event with nested types can be serialized and deserialized correctly."""
+        default_config = EventLogConfig()
+        storage = AsyncSQLiteEventStorage(db_path=temp_db_path, batch_size=default_config.batch_size, batch_timeout_ms=default_config.batch_timeout_ms, debounce_ms=default_config.debounce_ms, max_age_ms=default_config.max_age_ms)
+        await storage.start()
+
+        # Create a ChunkGenerated event with nested TokenChunk
+        task_id = TaskId(uuid=uuid4())
+        chunk_data = TokenChunkData(
+            text="Hello, world!",
+            token_id=42,
+            finish_reason="stop"
+        )
+        token_chunk = TokenChunk(
+            chunk_data=chunk_data,
+            chunk_type=ChunkType.token,
+            task_id=task_id,
+            idx=0,
+            model="test-model"
+        )
+
+        chunk_generated_event = ChunkGenerated(
+            event_type=StreamingEventTypes.ChunkGenerated,
+            event_category=EventCategoryEnum.MutatesTaskState,
+            task_id=task_id,
+            chunk=token_chunk
+        )
+
+        # Store the event using the storage API
+        await storage.append_events([chunk_generated_event], sample_node_id)  # type: ignore[reportArgumentType]
+
+        # Wait for batch to be written
+        await asyncio.sleep(0.5)
+
+        # Retrieve the event
+        events = await storage.get_events_since(0)
+
+        # Verify we got the event back
+        assert len(events) == 1
+        retrieved_event_wrapper = events[0]
+        assert retrieved_event_wrapper.origin == sample_node_id
+
+        # Verify the event was deserialized correctly
+        retrieved_event = retrieved_event_wrapper.event
+        assert isinstance(retrieved_event, ChunkGenerated)
+        assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated
+        assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState
+        assert retrieved_event.task_id == task_id
+
+        # Verify the nested chunk was deserialized correctly
+        retrieved_chunk = retrieved_event.chunk
+        assert isinstance(retrieved_chunk, TokenChunk)
+        assert retrieved_chunk.chunk_type == ChunkType.token
+        assert retrieved_chunk.task_id == task_id
+        assert retrieved_chunk.idx == 0
+        assert retrieved_chunk.model == "test-model"
+
+        # Verify the chunk data
+        assert retrieved_chunk.chunk_data.text == "Hello, world!"
+        assert retrieved_chunk.chunk_data.token_id == 42
+        assert retrieved_chunk.chunk_data.finish_reason == "stop"
+
         await storage.close()
\ No newline at end of file
diff --git a/shared/types/events/common.py b/shared/types/events/common.py
index 5dcbd945..cdae35c9 100644
--- a/shared/types/events/common.py
+++ b/shared/types/events/common.py
@@ -1,5 +1,6 @@
 from enum import Enum, StrEnum
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     FrozenSet,
@@ -10,6 +11,9 @@ from typing import (
     cast,
 )
 
+if TYPE_CHECKING:
+    pass
+
 from pydantic import BaseModel, Field, model_validator
 
 from shared.types.common import NewUUID, NodeId
@@ -138,7 +142,13 @@ class BaseEvent[
     event_category: SetMembersT
     event_id: EventId = EventId()
 
-    def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: ...
+    def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
+        """Check if the event was sent by the correct node.
+
+        This is a placeholder implementation that always returns True.
+        Subclasses can override this method to implement specific validation logic.
+        """
+        return True
 
 
 class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel):
diff --git a/shared/types/events/registry.py b/shared/types/events/registry.py
index 5748d6a8..8ba17138 100644
--- a/shared/types/events/registry.py
+++ b/shared/types/events/registry.py
@@ -1,7 +1,7 @@
 from types import UnionType
 from typing import Annotated, Any, Mapping, Type, get_args
 
-from pydantic import Field, TypeAdapter
+from pydantic import BaseModel, Field, TypeAdapter
 
 from shared.constants import get_error_reporting_message
 from shared.types.events.common import (
@@ -9,6 +9,7 @@ from shared.types.events.common import (
     EventCategories,
     EventTypes,
     InstanceEventTypes,
+    NodeId,
     NodePerformanceEventTypes,
     RunnerStatusEventTypes,
     StreamingEventTypes,
@@ -127,3 +128,16 @@ check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event)
 
 _EventType = Annotated[Event, Field(discriminator="event_type")]
 EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType)
+
+
+# Define a properly typed EventFromEventLog that preserves specific event types
+
+class EventFromEventLogTyped(BaseModel):
+    """Properly typed EventFromEventLog that preserves specific event types."""
+    event: _EventType
+    origin: NodeId
+    idx_in_log: int = Field(gt=0)
+
+    def check_event_was_sent_by_correct_node(self) -> bool:
+        """Check if the event was sent by the correct node."""
+        return self.event.check_event_was_sent_by_correct_node(self.origin)