import asyncio
import os
from logging import Logger
from typing import Callable, Final

import pytest

from shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager
from shared.types.common import Host, NodeId
from shared.types.events import InstanceCreated, InstanceDeleted
from shared.types.models import ModelId
from shared.types.worker.common import InstanceId, RunnerId
from shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from shared.types.worker.runners import FailedRunnerStatus
from shared.types.worker.shards import PipelineShardMetadata
from worker.download.shard_downloader import NoopShardDownloader
from worker.main import run
from worker.worker import Worker

MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")

RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
MODEL_A_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
MODEL_B_ID: Final[ModelId] = 'mlx-community/Llama-3.2-1B-Instruct-4bit'
TASK_1_ID: Final = "55555555-5555-4555-8555-555555555555"
TASK_2_ID: Final = "66666666-6666-4666-8666-666666666666"


@pytest.fixture
def user_message() -> str:
    return "What is the capital of Japan?"


@pytest.mark.skipif(
    os.environ.get("DETAILED", "").lower() != "true",
    reason="This test only runs when the DETAILED=true environment variable is set",
)
async def check_runner_connection(
    logger: Logger,
    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
    hosts: Callable[[int], list[Host]],
) -> bool:
    # Track all tasks and workers for cleanup
    tasks: list[asyncio.Task[None]] = []
    workers: list[Worker] = []

    try:
        event_log_manager = EventLogManager(EventLogConfig(), logger)
        await event_log_manager.initialize()
        shard_downloader = NoopShardDownloader()

        # Both workers share this global event log; clear any leftover events.
        global_events = event_log_manager.global_events
        await global_events.delete_all_events()

        # Spin up two workers, one per node, each running in its own task.
        worker1 = Worker(
            NODE_A,
            logger=logger,
            shard_downloader=shard_downloader,
            worker_events=global_events,
            global_events=global_events,
        )
        workers.append(worker1)
        task1 = asyncio.create_task(run(worker1))
        tasks.append(task1)

        worker2 = Worker(
            NODE_B,
            logger=logger,
            shard_downloader=shard_downloader,
            worker_events=global_events,
            global_events=global_events,
        )
        workers.append(worker2)
        task2 = asyncio.create_task(run(worker2))
        tasks.append(task2)

        model_id = ModelId('mlx-community/Llama-3.2-1B-Instruct-4bit')

        # Assign one pipeline shard to each runner and map each node to its runner.
        shard_assignments = ShardAssignments(
            model_id=model_id,
            runner_to_shard={
                RUNNER_1_ID: pipeline_shard_meta(2, 0),
                RUNNER_2_ID: pipeline_shard_meta(2, 1),
            },
            node_to_runner={
                NODE_A: RUNNER_1_ID,
                NODE_B: RUNNER_2_ID,
            },
        )

        instance = Instance(
            instance_id=INSTANCE_1_ID,
            instance_type=InstanceStatus.ACTIVE,
            shard_assignments=shard_assignments,
            hosts=hosts(2),
        )

        # Publish the new instance; the workers should react by starting runners.
        await global_events.append_events(
            [
                InstanceCreated(instance=instance),
            ],
            origin=MASTER_NODE_ID,
        )

        from worker.runner.runner_supervisor import RunnerSupervisor

        async def wait_for_runner_supervisor(worker: Worker, timeout: float = 5.0) -> RunnerSupervisor | None:
            # Poll until the worker's first assigned runner is either a live
            # RunnerSupervisor or has entered a failed state.
            end = asyncio.get_running_loop().time() + timeout
            while True:
                assigned_runners = list(worker.assigned_runners.values())
                if assigned_runners:
                    runner = assigned_runners[0].runner
                    if isinstance(runner, RunnerSupervisor):
                        print('breaking because success')
                        return runner
                    if isinstance(assigned_runners[0].status, FailedRunnerStatus):
                        print('breaking because failed')
                        return runner
                if asyncio.get_running_loop().time() > end:
                    raise TimeoutError("RunnerSupervisor was not set within timeout")
                await asyncio.sleep(0.001)

        runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0)
        ret = runner_supervisor is not None and runner_supervisor.healthy

        # Tear the instance back down and give the workers a moment to process it.
        await global_events.append_events(
            [
                InstanceDeleted(instance_id=instance.instance_id),
            ],
            origin=MASTER_NODE_ID,
        )

        await asyncio.sleep(0.5)

        return ret
    finally:
        # Cancel all worker tasks
        for task in tasks:
            task.cancel()

        # Wait for cancellation to complete
        await asyncio.gather(*tasks, return_exceptions=True)


# Check Running status

# Disabled for now:
# def test_runner_connection_stress(
#     logger: Logger,
#     pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
#     hosts: Callable[[int], list[Host]],
#     chat_completion_task: Callable[[InstanceId, str], Task],
# ) -> None:
#     total_runs = 100
#     successes = 0
#
#     for _ in range(total_runs):
#         # Create a fresh event loop for each iteration
#         loop = asyncio.new_event_loop()
#         asyncio.set_event_loop(loop)
#
#         try:
#             result = loop.run_until_complete(check_runner_connection(
#                 logger=logger,
#                 pipeline_shard_meta=pipeline_shard_meta,
#                 hosts=hosts,
#                 chat_completion_task=chat_completion_task,
#             ))
#             if result:
#                 successes += 1
#         finally:
#             # Cancel all running tasks
#             pending = asyncio.all_tasks(loop)
#             for task in pending:
#                 task.cancel()
#
#             # Run the event loop briefly to allow cancellation to complete
#             loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
#
#             # Close the event loop
#             loop.close()
#
#     print(f"Runner connection successes: {successes} / {total_runs}")