fix sqlite connector

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Gelu Vrabie committed 2025-07-21 22:43:09 +01:00 (committed by GitHub)
parent 449fdac27a
commit 108128b620
4 changed files with 112 additions and 13 deletions

File 1 of 4: shared/db/sqlite event storage (AsyncSQLiteEventStorage)

@@ -15,10 +15,9 @@ from sqlmodel import SQLModel
 from shared.types.events.common import (
     BaseEvent,
     EventCategories,
-    EventFromEventLog,
     NodeId,
 )
-from shared.types.events.registry import EventParser
+from shared.types.events.registry import Event, EventFromEventLogTyped, EventParser
 from .types import StoredEvent
@@ -87,7 +86,7 @@ class AsyncSQLiteEventStorage:
     async def get_events_since(
         self,
         last_idx: int
-    ) -> Sequence[EventFromEventLog[EventCategories]]:
+    ) -> Sequence[EventFromEventLogTyped]:
         """Retrieve events after a specific index."""
         if self._closed:
             raise RuntimeError("Storage is closed")
@@ -102,7 +101,7 @@ class AsyncSQLiteEventStorage:
             )
             rows = result.fetchall()
-            events: list[EventFromEventLog[EventCategories]] = []
+            events: list[EventFromEventLogTyped] = []
             for row in rows:
                 rowid: int = cast(int, row[0])
                 origin: str = cast(str, row[1])
@@ -113,7 +112,7 @@ class AsyncSQLiteEventStorage:
                 else:
                     event_data = cast(dict[str, Any], raw_event_data)
                 event = await self._deserialize_event(event_data)
-                events.append(EventFromEventLog(
+                events.append(EventFromEventLogTyped(
                     event=event,
                     origin=NodeId(uuid=UUID(origin)),
                     idx_in_log=rowid  # rowid becomes idx_in_log
@@ -215,13 +214,13 @@ class AsyncSQLiteEventStorage:
             try:
                 async with AsyncSession(self._engine) as session:
                     for event, origin in batch:
                         stored_event = StoredEvent(
                             origin=str(origin.uuid),
-                            event_type=event.event_type.value,
-                            event_category=next(iter(event.event_category)).value,
+                            event_type=str(event.event_type),
+                            event_category=str(next(iter(event.event_category))),
                             event_id=str(event.event_id),
-                            event_data=event.model_dump()  # SQLModel handles JSON serialization automatically
+                            event_data=event.model_dump(mode='json')  # mode='json' ensures UUID conversion
                         )
                         session.add(stored_event)
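A note on the model_dump change above: in Pydantic v2, model_dump() leaves rich Python values such as UUID objects in the output dict, which a JSON column serializer then rejects, while model_dump(mode='json') coerces every field to JSON-safe primitives first. A minimal self-contained sketch of the difference (the Ping model is illustrative, not from this codebase):

    from uuid import UUID, uuid4
    from pydantic import BaseModel

    class Ping(BaseModel):  # hypothetical model, for illustration only
        event_id: UUID

    p = Ping(event_id=uuid4())
    print(type(p.model_dump()["event_id"]))             # <class 'uuid.UUID'>, not JSON-serializable
    print(type(p.model_dump(mode="json")["event_id"]))  # <class 'str'>, safe for a JSON column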
@@ -233,9 +232,13 @@ class AsyncSQLiteEventStorage:
                 self._logger.error(f"Failed to commit batch: {e}")
                 raise

-    async def _deserialize_event(self, event_data: dict[str, Any]) -> BaseEvent[EventCategories]:
+    # TODO: This is a hack to get the event deserialization working. We need to find a better way to do this.
+    async def _deserialize_event(self, event_data: dict[str, Any]) -> Event:
         """Deserialize event data back to typed Event."""
-        return EventParser.validate_python(event_data)
+        # EventParser expects the discriminator field for proper deserialization
+        result = EventParser.validate_python(event_data)
+        # EventParser returns BaseEvent but we know it's actually a specific Event type
+        return result  # type: ignore[reportReturnType]

     async def _deserialize_event_raw(self, event_data: dict[str, Any]) -> dict[str, Any]:
         """Return raw event data for testing purposes."""

File 2 of 4: sqlite storage test suite (TestAsyncSQLiteEventStorage)

@@ -11,6 +11,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from shared.db.sqlite import AsyncSQLiteEventStorage, EventLogConfig
 from shared.types.common import NodeId
 from shared.types.events.chunks import ChunkType, TokenChunk, TokenChunkData
+from shared.types.events.events import (
+    ChunkGenerated,
+    EventCategoryEnum,
+    StreamingEventTypes,
+)
 from shared.types.tasks.common import TaskId

 # Type ignore comment for all protected member access in this test file
 # pyright: reportPrivateUsage=false
@@ -393,4 +400,69 @@ class TestAsyncSQLiteEventStorage:
         for i, row in enumerate(rows):
             assert row[0] == i + 1  # rowid should be sequential

         await storage.close()
+
+    @pytest.mark.asyncio
+    async def test_chunk_generated_event_serialization(self, temp_db_path: Path, sample_node_id: NodeId) -> None:
+        """Test that a ChunkGenerated event with nested types can be serialized and deserialized correctly."""
+        default_config = EventLogConfig()
+        storage = AsyncSQLiteEventStorage(
+            db_path=temp_db_path,
+            batch_size=default_config.batch_size,
+            batch_timeout_ms=default_config.batch_timeout_ms,
+            debounce_ms=default_config.debounce_ms,
+            max_age_ms=default_config.max_age_ms,
+        )
+        await storage.start()
+
+        # Create a ChunkGenerated event with a nested TokenChunk
+        task_id = TaskId(uuid=uuid4())
+        chunk_data = TokenChunkData(
+            text="Hello, world!",
+            token_id=42,
+            finish_reason="stop"
+        )
+        token_chunk = TokenChunk(
+            chunk_data=chunk_data,
+            chunk_type=ChunkType.token,
+            task_id=task_id,
+            idx=0,
+            model="test-model"
+        )
+        chunk_generated_event = ChunkGenerated(
+            event_type=StreamingEventTypes.ChunkGenerated,
+            event_category=EventCategoryEnum.MutatesTaskState,
+            task_id=task_id,
+            chunk=token_chunk
+        )
+
+        # Store the event using the storage API
+        await storage.append_events([chunk_generated_event], sample_node_id)  # type: ignore[reportArgumentType]
+
+        # Wait for the batch to be written
+        await asyncio.sleep(0.5)
+
+        # Retrieve the event
+        events = await storage.get_events_since(0)
+
+        # Verify we got the event back
+        assert len(events) == 1
+        retrieved_event_wrapper = events[0]
+        assert retrieved_event_wrapper.origin == sample_node_id
+
+        # Verify the event was deserialized correctly
+        retrieved_event = retrieved_event_wrapper.event
+        assert isinstance(retrieved_event, ChunkGenerated)
+        assert retrieved_event.event_type == StreamingEventTypes.ChunkGenerated
+        assert retrieved_event.event_category == EventCategoryEnum.MutatesTaskState
+        assert retrieved_event.task_id == task_id
+
+        # Verify the nested chunk was deserialized correctly
+        retrieved_chunk = retrieved_event.chunk
+        assert isinstance(retrieved_chunk, TokenChunk)
+        assert retrieved_chunk.chunk_type == ChunkType.token
+        assert retrieved_chunk.task_id == task_id
+        assert retrieved_chunk.idx == 0
+        assert retrieved_chunk.model == "test-model"
+
+        # Verify the chunk data
+        assert retrieved_chunk.chunk_data.text == "Hello, world!"
+        assert retrieved_chunk.chunk_data.token_id == 42
+        assert retrieved_chunk.chunk_data.finish_reason == "stop"
+
+        await storage.close()
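One fragility worth noting in this test: the fixed asyncio.sleep(0.5) assumes the default batch_timeout_ms flushes well within 500 ms. A sketch of a polling alternative that uses only the storage API shown above (the helper itself is hypothetical, not part of this commit):

    async def wait_for_events(storage: AsyncSQLiteEventStorage, n: int, timeout: float = 5.0):
        # Poll get_events_since until n events are visible or the timeout expires.
        deadline = asyncio.get_running_loop().time() + timeout
        while asyncio.get_running_loop().time() < deadline:
            events = await storage.get_events_since(0)
            if len(events) >= n:
                return events
            await asyncio.sleep(0.05)
        raise AssertionError(f"expected {n} events before timeout")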

File 3 of 4: shared/types/events/common.py

@@ -1,5 +1,6 @@
 from enum import Enum, StrEnum
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     FrozenSet,
@@ -10,6 +11,9 @@ from typing import (
     cast,
 )

+if TYPE_CHECKING:
+    pass
+
 from pydantic import BaseModel, Field, model_validator

 from shared.types.common import NewUUID, NodeId
@@ -138,7 +142,13 @@ class BaseEvent[
     event_category: SetMembersT
     event_id: EventId = EventId()

-    def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool: ...
+    def check_event_was_sent_by_correct_node(self, origin_id: NodeId) -> bool:
+        """Check if the event was sent by the correct node.
+
+        This is a placeholder implementation that always returns True.
+        Subclasses can override this method to implement specific validation logic.
+        """
+        return True


 class EventFromEventLog[SetMembersT: EventCategories | EventCategory](BaseModel):
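With the base implementation now permissive by design, any real enforcement has to live in subclasses. A self-contained sketch of what an overriding event could look like (every name here is a stand-in, none are classes from this diff):

    from pydantic import BaseModel

    class NodeIdSketch(BaseModel):  # stand-in for shared.types.common.NodeId
        uuid: str

    class BaseEventSketch(BaseModel):  # stand-in base with the permissive default
        def check_event_was_sent_by_correct_node(self, origin_id: NodeIdSketch) -> bool:
            return True

    class NodeScopedEvent(BaseEventSketch):  # hypothetical event tied to one node
        node_id: NodeIdSketch

        def check_event_was_sent_by_correct_node(self, origin_id: NodeIdSketch) -> bool:
            # Reject events that claim to describe a node other than their origin.
            return origin_id == self.node_id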

File 4 of 4: shared/types/events/registry.py

@@ -1,7 +1,7 @@
 from types import UnionType
 from typing import Annotated, Any, Mapping, Type, get_args

-from pydantic import Field, TypeAdapter
+from pydantic import BaseModel, Field, TypeAdapter

 from shared.constants import get_error_reporting_message
 from shared.types.events.common import (
@@ -9,6 +9,7 @@ from shared.types.events.common import (
     EventCategories,
     EventTypes,
     InstanceEventTypes,
+    NodeId,
     NodePerformanceEventTypes,
     RunnerStatusEventTypes,
     StreamingEventTypes,
@@ -127,3 +128,16 @@ check_union_of_all_events_is_consistent_with_registry(EventRegistry, Event)
 _EventType = Annotated[Event, Field(discriminator="event_type")]
 EventParser: TypeAdapter[BaseEvent[EventCategories]] = TypeAdapter(_EventType)
+
+
+# Define a properly typed EventFromEventLog that preserves specific event types
+class EventFromEventLogTyped(BaseModel):
+    """Properly typed EventFromEventLog that preserves specific event types."""
+    event: _EventType
+    origin: NodeId
+    idx_in_log: int = Field(gt=0)
+
+    def check_event_was_sent_by_correct_node(self) -> bool:
+        """Check if the event was sent by the correct node."""
+        return self.event.check_event_was_sent_by_correct_node(self.origin)
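The practical payoff over the generic EventFromEventLog: because event is typed as the discriminated union rather than BaseEvent[EventCategories], an isinstance check narrows it to a concrete event class, which is exactly what the new test relies on. A hedged usage sketch (the replay function is hypothetical; the imports match those used in the test file above):

    from shared.db.sqlite import AsyncSQLiteEventStorage
    from shared.types.events.events import ChunkGenerated

    async def replay(storage: AsyncSQLiteEventStorage) -> None:
        for wrapper in await storage.get_events_since(0):
            assert wrapper.check_event_was_sent_by_correct_node()
            if isinstance(wrapper.event, ChunkGenerated):
                # isinstance narrows wrapper.event to ChunkGenerated,
                # so its chunk payload is accessible with full typing.
                print(wrapper.event.chunk.task_id)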