mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
74 lines
2.7 KiB
Python
74 lines
2.7 KiB
Python
import asyncio
import contextlib
import shutil
import tempfile
from logging import Logger
from pathlib import Path
from typing import List

import pytest

from master.main import Master
from shared.db.sqlite.config import EventLogConfig
from shared.db.sqlite.connector import AsyncSQLiteEventStorage
from shared.db.sqlite.event_log_manager import EventLogManager
from shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from shared.types.common import NodeId
from shared.types.events import TaskCreated
from shared.types.events.commands import ChatCompletionCommand, Command, CommandId
from shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
|
|
|
|
|
|
def _create_forwarder_dummy_binary() -> Path:
    """Create a throwaway executable shell script standing in for the forwarder.

    The script echoes a marker line and then sleeps for a very long time, so
    anything that launches it sees a long-running process. Each call writes
    the script into its own freshly created temporary directory; the caller
    is responsible for removing ``path.parent`` when done.

    Returns:
        Path to the executable ``forwarder.bin`` script.
    """
    # mkdtemp() creates the directory atomically (owner-only permissions),
    # unlike the deprecated tempfile.mktemp(), which only returns a name that
    # another process could claim before we create it. Because the directory
    # is brand new, the script cannot already exist and no mkdir is needed.
    path = Path(tempfile.mkdtemp(prefix="exo-forwarder-")) / "forwarder.bin"
    path.write_bytes(b"#!/bin/sh\necho dummy forwarder && sleep 1000000\n")
    # rwxr-xr-x so the script can be executed directly.
    path.chmod(0o755)
    return path
|
|
|
|
@pytest.mark.asyncio
async def test_master():
    """Check that the master turns one chat-completion command into one event.

    Clears the global event log, starts ``Master.run()`` in the background,
    enqueues a single ``ChatCompletionCommand`` and asserts that exactly one
    ``TaskCreated`` event (a pending chat-completion task carrying the same
    request params) is written to the log and that the command buffer ends
    up empty. The background master task and the dummy forwarder binary are
    always cleaned up, even when an assertion fails.
    """
    # Upper bound on how long to wait for the master to emit its first event;
    # without it a broken master would hang the test run forever.
    first_event_timeout_s = 10.0

    logger = Logger(name='test_master_logger')
    event_log_manager = EventLogManager(EventLogConfig(), logger=logger)
    await event_log_manager.initialize()
    global_events: AsyncSQLiteEventStorage = event_log_manager.global_events
    # Start from an empty log so the event-count / idx_in_log checks are exact.
    await global_events.delete_all_events()

    command_buffer: List[Command] = []

    forwarder_binary_path = _create_forwarder_dummy_binary()

    node_id = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
    master = Master(
        node_id,
        command_buffer=command_buffer,
        global_events=global_events,
        forwarder_binary_path=forwarder_binary_path,
        logger=logger,
    )
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    master_task = asyncio.create_task(master.run())

    def make_params() -> ChatCompletionTaskParams:
        # Built fresh on every call so the expected value below is compared
        # by equality, independent of the instance handed to the master.
        return ChatCompletionTaskParams(
            model="llama-3.2-1b",
            messages=[ChatCompletionMessage(role="user", content="Hello, how are you?")],
        )

    async def wait_for_first_event() -> None:
        while len(await global_events.get_events_since(0)) == 0:
            await asyncio.sleep(0.001)

    try:
        command_buffer.append(
            ChatCompletionCommand(
                command_id=CommandId(),
                request_params=make_params(),
            )
        )

        await asyncio.wait_for(wait_for_first_event(), timeout=first_event_timeout_s)

        events = await global_events.get_events_since(0)
        assert len(events) == 1
        assert events[0].idx_in_log == 1
        assert isinstance(events[0].event, TaskCreated)
        assert events[0].event == TaskCreated(
            task_id=events[0].event.task_id,
            task=ChatCompletionTask(
                task_id=events[0].event.task_id,
                task_type=TaskType.CHAT_COMPLETION,
                instance_id=events[0].event.task.instance_id,
                task_status=TaskStatus.PENDING,
                task_params=make_params(),
            ),
        )
        assert len(command_buffer) == 0
    finally:
        # Stop the background master so no pending task outlives the test.
        # Awaiting it also re-raises any real exception run() died with,
        # which is far more informative than a bare timeout.
        master_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await master_task
        # The helper places the script in its own dedicated temp directory.
        shutil.rmtree(forwarder_binary_path.parent, ignore_errors=True)
|