mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Co-authored-by: Gelu Vrabie <gelu@exolabs.net> Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com> Co-authored-by: Seth Howes <71157822+sethhowes@users.noreply.github.com> Co-authored-by: Matt Beton <matthew.beton@gmail.com> Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from typing import Callable, TypeVar
|
|
|
|
from pydantic import BaseModel, TypeAdapter
|
|
|
|
from shared.types.common import Host
|
|
from shared.types.tasks import Task, TaskId
|
|
from shared.types.worker.commands_runner import (
|
|
ChatTaskMessage,
|
|
RunnerMessageTypeAdapter,
|
|
SetupMessage,
|
|
)
|
|
from shared.types.worker.common import InstanceId
|
|
from shared.types.worker.shards import PipelineShardMetadata
|
|
|
|
# Generic type variable restricted to pydantic ``BaseModel`` subclasses; it ties
# the object passed to ``assert_equal_serdes`` to the adapter that decodes it.
T = TypeVar("T", bound=BaseModel)
|
|
|
|
|
|
def assert_equal_serdes(obj: T, typeadapter: TypeAdapter[T]):
    """Assert that ``obj`` survives a JSON round trip through ``typeadapter``.

    ``obj`` is dumped to JSON, UTF-8 encoded and terminated with a newline,
    then parsed back with ``typeadapter.validate_json``. The parsed value
    must compare equal to ``obj``.

    Args:
        obj: The model instance to serialize.
        typeadapter: The adapter used to parse the encoded bytes.

    Raises:
        AssertionError: If the decoded value differs from ``obj``; the
            message shows both values and the raw encoded bytes.
    """
    # NOTE(review): the trailing newline presumably mirrors line-delimited
    # framing used by the real transport -- confirm against the runner code.
    encoded = f"{obj.model_dump_json()}\n".encode("utf-8")
    decoded = typeadapter.validate_json(encoded)

    assert decoded == obj, (
        f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}"
    )
|
|
|
|
|
|
def test_supervisor_setup_message_serdes(
    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
    hosts: Callable[..., list[Host]],
):
    """Check that a ``SetupMessage`` round-trips via ``RunnerMessageTypeAdapter``.

    Args:
        pipeline_shard_meta: Factory producing the shard metadata for the
            message (presumably a pytest fixture -- confirm in conftest).
        hosts: Factory producing the host list for the message (presumably
            a pytest fixture -- confirm in conftest).
    """
    # Build the message fields first, in the same order as the keyword
    # arguments they feed, then hand the finished message to the helper.
    shard_meta = pipeline_shard_meta(1, 0)
    host_list = hosts(1)
    assert_equal_serdes(
        SetupMessage(model_shard_meta=shard_meta, hosts=host_list),
        RunnerMessageTypeAdapter,
    )
|
|
|
|
|
|
def test_supervisor_task_message_serdes(
    chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
    """Check that a ``ChatTaskMessage`` round-trips via ``RunnerMessageTypeAdapter``.

    Args:
        chat_completion_task: Factory that builds a ``Task`` from an instance
            id and a task id (presumably a pytest fixture -- confirm in
            conftest).
    """
    # Only the task's parameters are carried by the message, not the task itself.
    params = chat_completion_task(InstanceId(), TaskId()).task_params
    assert_equal_serdes(
        ChatTaskMessage(task_data=params),
        RunnerMessageTypeAdapter,
    )
|