mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
chore: Run formatter
This commit is contained in:
@@ -3,15 +3,18 @@ from typing import TYPE_CHECKING, Literal, TypeAlias, get_type_hints
|
||||
if TYPE_CHECKING:
|
||||
import openai.types as openai_types
|
||||
import openai.types.chat as openai_chat
|
||||
|
||||
types = openai_types
|
||||
chat = openai_chat
|
||||
else:
|
||||
types = None
|
||||
chat = None
|
||||
|
||||
FinishReason: TypeAlias = Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
|
||||
assert get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] == FinishReason, (
|
||||
"Upstream changed Choice.finish_reason; update FinishReason alias."
|
||||
)
|
||||
FinishReason: TypeAlias = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
]
|
||||
assert (
|
||||
get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] == FinishReason
|
||||
), "Upstream changed Choice.finish_reason; update FinishReason alias."
|
||||
|
||||
__all__ = ["types", "chat", "FinishReason"]
|
||||
__all__ = ["types", "chat", "FinishReason"]
|
||||
|
||||
@@ -60,12 +60,12 @@ class Event(BaseModel, Generic[TEventType]):
|
||||
event_type: TEventType
|
||||
event_id: EventId
|
||||
|
||||
|
||||
class PersistedEvent(BaseModel, Generic[TEventType]):
|
||||
event: Event[TEventType]
|
||||
sequence_number: int = Field(gt=0)
|
||||
|
||||
|
||||
|
||||
class State(BaseModel, Generic[EventTypeT]):
|
||||
event_types: tuple[EventTypeT, ...] = get_args(EventTypeT)
|
||||
sequence_number: int = Field(default=0, ge=0)
|
||||
@@ -76,7 +76,9 @@ EventTypeParser: TypeAdapter[AnnotatedEventType] = TypeAdapter(AnnotatedEventTyp
|
||||
|
||||
Applicator = Callable[[State[EventTypeT], Event[TEventType]], State[EventTypeT]]
|
||||
Apply = Callable[[State[EventTypeT], Event[EventTypeT]], State[EventTypeT]]
|
||||
SagaApplicator = Callable[[State[EventTypeT], Event[TEventType]], Sequence[Event[EventTypeT]]]
|
||||
SagaApplicator = Callable[
|
||||
[State[EventTypeT], Event[TEventType]], Sequence[Event[EventTypeT]]
|
||||
]
|
||||
Saga = Callable[[State[EventTypeT], Event[EventTypeT]], Sequence[Event[EventTypeT]]]
|
||||
|
||||
StateAndEvent = Tuple[State[EventTypeT], Event[EventTypeT]]
|
||||
@@ -130,4 +132,6 @@ class Command(BaseModel, Generic[TEventType, TCommandType]):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
Decide = Callable[[State[EventTypeT], Command[TEventType, TCommandType]], Sequence[Event[EventTypeT]]]
|
||||
Decide = Callable[
|
||||
[State[EventTypeT], Command[TEventType, TCommandType]], Sequence[Event[EventTypeT]]
|
||||
]
|
||||
|
||||
@@ -25,17 +25,21 @@ _TimerId = Annotated[UUID, UuidVersion(4)]
|
||||
TimerId = type("TimerId", (UUID,), {})
|
||||
TimerIdParser: TypeAdapter[TimerId] = TypeAdapter(_TimerId)
|
||||
|
||||
|
||||
class Shard(BaseModel):
|
||||
# TODO: this has changed
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class InstanceComputePlan(BaseModel):
|
||||
# TODO: this has changed
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class Timer(BaseModel):
|
||||
timer_id: TimerId
|
||||
|
||||
|
||||
# Chat completions ----------------------------------------------------------------
|
||||
class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted"]]):
|
||||
event_type = "ChatCompletionsRequestStarted"
|
||||
@@ -44,7 +48,9 @@ class ChatCompletionsRequestStarted(Event[Literal["ChatCompletionsRequestStarted
|
||||
request: chat.completion_create_params.CompletionCreateParams
|
||||
|
||||
|
||||
class ChatCompletionsRequestCompleted(Event[Literal["ChatCompletionsRequestCompleted"]]):
|
||||
class ChatCompletionsRequestCompleted(
|
||||
Event[Literal["ChatCompletionsRequestCompleted"]]
|
||||
):
|
||||
event_type = "ChatCompletionsRequestCompleted"
|
||||
request_id: RequestId
|
||||
model_id: ModelId
|
||||
@@ -265,6 +271,7 @@ class DeviceProfiled(Event[Literal["DeviceProfiled"]]):
|
||||
available_memory_bytes: int
|
||||
total_flops_fp16: int
|
||||
|
||||
|
||||
# Token streaming ----------------------------------------------------------------
|
||||
class TokenGenerated(Event[Literal["TokenGenerated"]]):
|
||||
# TODO: replace with matt chunk code
|
||||
|
||||
@@ -12,37 +12,40 @@ _ModelId = Annotated[UUID, UuidVersion(4)]
|
||||
ModelId = type("ModelId", (UUID,), {})
|
||||
ModelIdParser: TypeAdapter[ModelId] = TypeAdapter(_ModelId)
|
||||
|
||||
RepoPath = Annotated[str, Field(pattern=r'^[^/]+/[^/]+$')]
|
||||
RepoPath = Annotated[str, Field(pattern=r"^[^/]+/[^/]+$")]
|
||||
RepoURL = Annotated[str, AnyHttpUrl]
|
||||
|
||||
|
||||
class BaseModelSource(BaseModel, Generic[T]):
|
||||
model_uuid: ModelId
|
||||
source_type: T
|
||||
source_data: Any
|
||||
|
||||
|
||||
@final
|
||||
class HuggingFaceModelSourceData(BaseModel):
|
||||
path: RepoPath
|
||||
|
||||
|
||||
@final
|
||||
class GitHubModelSourceData(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
|
||||
|
||||
@final
|
||||
class HuggingFaceModelSource(BaseModelSource[Literal["HuggingFace"]]):
|
||||
source_type: Literal["HuggingFace"] = "HuggingFace"
|
||||
source_data: HuggingFaceModelSourceData
|
||||
|
||||
|
||||
@final
|
||||
class GitHubModelSource(BaseModelSource[Literal["GitHub"]]):
|
||||
source_type: Literal["GitHub"] = "GitHub"
|
||||
source_data: GitHubModelSourceData
|
||||
|
||||
|
||||
RepoType = BaseModelSource[SourceType]
|
||||
|
||||
RepoValidatorThing = Annotated[
|
||||
RepoType,
|
||||
Field(discriminator="source_type")
|
||||
]
|
||||
RepoValidatorThing = Annotated[RepoType, Field(discriminator="source_type")]
|
||||
|
||||
RepoValidator: TypeAdapter[RepoValidatorThing] = TypeAdapter(RepoValidatorThing)
|
||||
|
||||
Reference in New Issue
Block a user