diff --git a/shared/openai.py b/shared/openai.py index 1caa4a43..0a0a546f 100644 --- a/shared/openai.py +++ b/shared/openai.py @@ -3,15 +3,18 @@ from typing import TYPE_CHECKING, Literal, TypeAlias, get_type_hints if TYPE_CHECKING: import openai.types as openai_types import openai.types.chat as openai_chat + types = openai_types chat = openai_chat else: types = None chat = None -FinishReason: TypeAlias = Literal["stop", "length", "tool_calls", "content_filter", "function_call"] -assert get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] == FinishReason, ( - "Upstream changed Choice.finish_reason; update FinishReason alias." -) +FinishReason: TypeAlias = Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call" +] +assert ( + get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] == FinishReason +), "Upstream changed Choice.finish_reason; update FinishReason alias." -__all__ = ["types", "chat", "FinishReason"] \ No newline at end of file +__all__ = ["types", "chat", "FinishReason"] diff --git a/shared/types/event_sourcing.py b/shared/types/event_sourcing.py index ed239e43..e4b6138b 100644 --- a/shared/types/event_sourcing.py +++ b/shared/types/event_sourcing.py @@ -60,12 +60,12 @@ class Event(BaseModel, Generic[TEventType]): event_type: TEventType event_id: EventId + class PersistedEvent(BaseModel, Generic[TEventType]): event: Event[TEventType] sequence_number: int = Field(gt=0) - class State(BaseModel, Generic[EventTypeT]): event_types: tuple[EventTypeT, ...] = get_args(EventTypeT) sequence_number: int = Field(default=0, ge=0) @@ -76,7 +76,9 @@ EventTypeParser: TypeAdapter[AnnotatedEventType] = TypeAdapter(AnnotatedEventTyp Applicator = Callable[[State[EventTypeT], Event[TEventType]], State[EventTypeT]] Apply = Callable[[State[EventTypeT], Event[EventTypeT]], State[EventTypeT]] -SagaApplicator = Callable[[State[EventTypeT], Event[TEventType]], Sequence[Event[EventTypeT]]] +SagaApplicator = Callable[ + [State[EventTypeT], Event[TEventType]], Sequence[Event[EventTypeT]] +] Saga = Callable[[State[EventTypeT], Event[EventTypeT]], Sequence[Event[EventTypeT]]] StateAndEvent = Tuple[State[EventTypeT], Event[EventTypeT]] @@ -130,4 +132,6 @@ class Command(BaseModel, Generic[TEventType, TCommandType]): command_id: CommandId -Decide = Callable[[State[EventTypeT], Command[TEventType, TCommandType]], Sequence[Event[EventTypeT]]] +Decide = Callable[ + [State[EventTypeT], Command[TEventType, TCommandType]], Sequence[Event[EventTypeT]] +] diff --git a/shared/types/events.py b/shared/types/events.py index 8b9b9cb5..9e79e659 100644 --- a/shared/types/events.py +++ b/shared/types/events.py @@ -25,17 +25,21 @@ _TimerId = Annotated[UUID, UuidVersion(4)] TimerId = type("TimerId", (UUID,), {}) TimerIdParser: TypeAdapter[TimerId] = TypeAdapter(_TimerId) + class Shard(BaseModel): # TODO: this has changed model_id: ModelId + class InstanceComputePlan(BaseModel): # TODO: this has changed model_id: ModelId + class Timer(BaseModel): timer_id: TimerId + # Chat completions ---------------------------------------------------------------- class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted"]]): event_type = "ChatCompletionsRequestStarted" @@ -44,7 +48,9 @@ class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted request: chat.completion_create_params.CompletionCreateParams -class ChatCompletionsRequestCompleted(Event[Literal["ChatCompletionsRequestCompleted"]]): +class ChatCompletionsRequestCompleted( + Event[Literal["ChatCompletionsRequestCompleted"]] +): event_type = "ChatCompletionsRequestCompleted" request_id: RequestId model_id: ModelId @@ -265,6 +271,7 @@ class DeviceProfiled(Event[Literal["DeviceProfiled"]]): available_memory_bytes: int total_flops_fp16: int + # Token streaming ---------------------------------------------------------------- class TokenGenerated(Event[Literal["TokenGenerated"]]): # TODO: replace with matt chunk code diff --git a/shared/types/model.py b/shared/types/model.py index b9b3fc8c..d0e11ed6 100644 --- a/shared/types/model.py +++ b/shared/types/model.py @@ -12,37 +12,40 @@ _ModelId = Annotated[UUID, UuidVersion(4)] ModelId = type("ModelId", (UUID,), {}) ModelIdParser: TypeAdapter[ModelId] = TypeAdapter(_ModelId) -RepoPath = Annotated[str, Field(pattern=r'^[^/]+/[^/]+$')] +RepoPath = Annotated[str, Field(pattern=r"^[^/]+/[^/]+$")] RepoURL = Annotated[str, AnyHttpUrl] + class BaseModelSource(BaseModel, Generic[T]): model_uuid: ModelId source_type: T source_data: Any + @final class HuggingFaceModelSourceData(BaseModel): path: RepoPath + @final class GitHubModelSourceData(BaseModel): url: AnyHttpUrl + @final class HuggingFaceModelSource(BaseModelSource[Literal["HuggingFace"]]): source_type: Literal["HuggingFace"] = "HuggingFace" source_data: HuggingFaceModelSourceData + @final class GitHubModelSource(BaseModelSource[Literal["GitHub"]]): source_type: Literal["GitHub"] = "GitHub" source_data: GitHubModelSourceData + RepoType = BaseModelSource[SourceType] -RepoValidatorThing = Annotated[ - RepoType, - Field(discriminator="source_type") -] +RepoValidatorThing = Annotated[RepoType, Field(discriminator="source_type")] RepoValidator: TypeAdapter[RepoValidatorThing] = TypeAdapter(RepoValidatorThing)