Refactor runner supervisor

Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
Author: Gelu Vrabie
Date: 2025-08-18 18:37:52 +01:00 (committed by GitHub)
parent 345fafd80d
commit ea9e573409
13 changed files with 351 additions and 260 deletions


@@ -7,12 +7,14 @@ from typing import Any, Callable
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.generate import stream_generate # type: ignore
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer # type: ignore
from mlx_lm.utils import load_model # type: ignore
from pydantic import RootModel
from engines.mlx.auto_parallel import auto_parallel
from shared.types.api import ChatCompletionMessage
from shared.types.common import Host
from shared.types.tasks import ChatCompletionTaskParams
from shared.types.worker.shards import ShardMetadata
@@ -134,6 +136,46 @@ async def apply_chat_template(
return prompt
async def warmup_inference(
mlx_executor: concurrent.futures.ThreadPoolExecutor,
model: nn.Module,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
) -> int:
loop = asyncio.get_running_loop()
warmup_prompt = await apply_chat_template(
mlx_executor=mlx_executor,
tokenizer=tokenizer,
chat_task_data=ChatCompletionTaskParams(
model="warmup",
messages=[
ChatCompletionMessage(
role='user',
content='Prompt to warm up the inference engine. Repeat this.'
)
]
),
)
tokens_generated = 0
def _generate_warmup():
nonlocal tokens_generated
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
max_tokens=50,
sampler=sampler,
):
tokens_generated += 1
await loop.run_in_executor(mlx_executor, _generate_warmup)
mx_barrier()
return tokens_generated
def mlx_force_oom(size: int = 40000) -> None:
"""


@@ -23,8 +23,6 @@ async def supervisor_write_message(
)
encoded: bytes = message.model_dump_json().encode("utf-8") + b"\n"
print(f"message: {message}")
# print(f"encoded: {encoded}")
proc.stdin.write(encoded)
await proc.stdin.drain()
@@ -63,12 +61,11 @@ async def supervisor_read_response(
"proc.stdout should not be None when created with stdout=PIPE"
)
line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180)
if not line_bytes:
# return None
raise EOFError("No more data to read when reading response from runner")
line: str = line_bytes.decode("utf-8").strip()
if not line:
return None
try:
return RunnerResponseTypeAdapter.validate_json(line)
except Exception as err:
@@ -98,4 +95,4 @@ def runner_write_error(error: Exception) -> None:
error_message=str(error),
traceback=traceback.format_exc(),
)
runner_write_response(error_response)
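
These helpers define the supervisor/runner wire protocol: one JSON-encoded Pydantic model per line over the subprocess pipes. A condensed sketch of the framing, assuming Pydantic v2 and omitting the schema validation the real helpers perform:

import asyncio
from pydantic import BaseModel

async def write_frame(proc: asyncio.subprocess.Process, msg: BaseModel) -> None:
    assert proc.stdin is not None, "process must be created with stdin=PIPE"
    proc.stdin.write(msg.model_dump_json().encode("utf-8") + b"\n")  # one object per line
    await proc.stdin.drain()

async def read_frame(proc: asyncio.subprocess.Process) -> str | None:
    assert proc.stdout is not None, "process must be created with stdout=PIPE"
    line_bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180)
    if not line_bytes:
        # readline() returns b"" only at EOF, i.e. the runner closed its stdout
        raise EOFError("No more data to read when reading response from runner")
    line = line_bytes.decode("utf-8").strip()
    return line or None  # a blank line carries no frame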


@@ -10,7 +10,12 @@ import mlx.nn as nn
from mlx_lm.generate import stream_generate # type: ignore
from mlx_lm.tokenizer_utils import TokenizerWrapper
from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx, mlx_force_oom
from engines.mlx.utils_mlx import (
apply_chat_template,
initialize_mlx,
mlx_force_oom,
warmup_inference,
)
from shared.openai_compat import FinishReason
from shared.types.tasks import ChatCompletionTaskParams
from shared.types.worker.commands_runner import (
@@ -122,6 +127,13 @@ async def main():
partial(initialize_mlx, model_shard_meta=model_shard_meta, hosts=hosts),
)
toks = await warmup_inference(
mlx_executor=mlx_executor,
model=model,
tokenizer=tokenizer,
sampler=sampler,
)
runner_print(f'Warmed up by generating {toks} tokens')
runner_write_response(InitializedResponse(time_taken=time.time() - setup_start_time))
while True:
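
One detail worth noting above: loop.run_in_executor forwards only positional arguments, which is why initialize_mlx is wrapped in functools.partial to bind its keyword arguments. A self-contained illustration (the load function is a stand-in, not repo code):

import asyncio
import functools

def load(model_id: str, *, device: str) -> str:
    return f"{model_id} on {device}"

async def demo() -> None:
    loop = asyncio.get_running_loop()
    # run_in_executor accepts no keyword arguments, so bind them up front.
    result = await loop.run_in_executor(
        None,  # default thread pool
        functools.partial(load, "demo-model", device="gpu"),
    )
    print(result)

asyncio.run(demo())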


@@ -1,6 +1,5 @@
import asyncio
import contextlib
import time
import traceback
from collections.abc import AsyncGenerator
from logging import Logger
@@ -19,6 +18,7 @@ from shared.types.worker.commands_runner import (
GenerationResponse,
InitializedResponse,
PrintResponse,
RunnerMessage,
RunnerResponse,
SetupMessage,
)
@@ -34,37 +34,34 @@ from worker.runner.utils import (
get_runner_command,
get_token_generate_timeout,
get_weights_size_kb,
kill_process_tree,
)
class RunnerSupervisor:
"""
RunnerSupervisor manages the lifecycle of a runner subprocess for model inference.
Use the class method `create` to properly initialize an instance.
"""
# TODO: Logger.
def __init__(
self,
model_shard_meta: ShardMetadata,
hosts: list[Host],
runner_process: asyncio.subprocess.Process,
logger: Logger,
read_queue: asyncio.Queue[RunnerResponse],
write_queue: asyncio.Queue[RunnerMessage],
stderr_queue: asyncio.Queue[str],
):
"""Private constructor. Use RunnerSupervisor.create() instead."""
self.model_shard_meta: ShardMetadata = model_shard_meta
self.hosts: list[Host] = hosts
self.runner_process: asyncio.subprocess.Process = runner_process
self.running: bool = True
self.stderr_queue = asyncio.Queue[tuple[float, str]]()
self.stderr_task = asyncio.create_task(self._watch_stderr(logger, self.stderr_queue))
self.running_task: asyncio.Task[None] = asyncio.create_task(
self._watch_runner()
)
self.logger = logger
self.returncode: int | None = None
self.stderr_outpu: str | None = None
self.model_shard_meta = model_shard_meta
self.hosts = hosts
self.runner_process = runner_process
self.read_queue = read_queue
self.write_queue = write_queue
self.stderr_queue = stderr_queue
self.read_task = asyncio.create_task(self._read_coro())
self.write_task = asyncio.create_task(self._write_coro())
self.stderr_task = asyncio.create_task(self._watch_stderr())
@classmethod
async def create(
@@ -79,8 +76,7 @@ class RunnerSupervisor:
The .create() classmethod pattern is used to ensure the constructor is asynchronous.
"""
cmd: list[str] = get_runner_command()
runner_process: asyncio.subprocess.Process = (
runner_process = (
await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
@@ -88,63 +84,170 @@ class RunnerSupervisor:
stderr=asyncio.subprocess.PIPE,
)
)
logger.info(f'initializing mlx instance with {model_shard_meta=}')
read_queue: asyncio.Queue[RunnerResponse] = asyncio.Queue()
write_queue: asyncio.Queue[RunnerMessage] = asyncio.Queue()
stderr_queue: asyncio.Queue[str] = asyncio.Queue()
self = cls(
model_shard_meta=model_shard_meta,
hosts=hosts,
runner_process=runner_process,
logger=logger,
read_queue=read_queue,
write_queue=write_queue,
stderr_queue=stderr_queue,
)
await supervisor_write_message(
runner_process,
SetupMessage(
model_shard_meta=model_shard_meta,
hosts=hosts,
),
)
async def read_initialization_message() -> None:
while True:
try:
line: RunnerResponse | None = await supervisor_read_response(
self.runner_process
)
if line is None:
continue
except EOFError:
if not self.runner_process.returncode:
continue
raise await self._raise_crashed() from EOFError
if isinstance(line, PrintResponse):
self.logger.info(f"runner printed: {line.text}")
continue
elif isinstance(line, ErrorResponse):
raise RunnerError(line.error_type, line.error_message, line.traceback or "")
elif isinstance(line, InitializedResponse):
assert isinstance(line, InitializedResponse)
logger.info(f'Runner initialized in {line.time_taken} seconds')
break
else:
raise AssertionError(f'Non-valid line read from runner during initialization: {line}')
self.logger.info(f'initializing mlx instance with {model_shard_meta=}')
await self.write_queue.put(SetupMessage(
model_shard_meta=model_shard_meta,
hosts=hosts,
))
if not initialize_timeout:
initialize_timeout = get_init_timeout(model_shard_meta)
await asyncio.wait_for(read_initialization_message(), timeout=initialize_timeout)
response = await self._read_with_error_check(initialize_timeout)
assert isinstance(response, InitializedResponse)
self.logger.info(f'Runner initialized in {response.time_taken} seconds')
return self
async def _read_with_error_check(self, timeout: float) -> RunnerResponse:
"""
Read from the queue with a timeout, but also check if the read_task has failed.
"""
queue_task = asyncio.create_task(self.read_queue.get())
done, pending = await asyncio.wait(
[queue_task, self.read_task],
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
if task is queue_task:
task.cancel()
if queue_task in done:
response = await queue_task
if isinstance(response, ErrorResponse):
raise RunnerError(response.error_type, response.error_message, response.traceback or "")
return response
if self.read_task in done:
await self.read_task # Re-raises any exception from read_task
self.logger.error('Unreachable code reached: read_task completed without raising an error.')
# if we haven't read from the queue, we have timed out.
await self.astop()
raise asyncio.TimeoutError()
async def stream_response(
self,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
) -> AsyncGenerator[GenerationChunk]:
"""
Streams a chat request from the model.
The request is pushed to the runner, and if the shard is the terminal shard, the response is streamed back to the worker.
request_started_callback is called once the request is pushed to the runner, used to publish InferencePrepareCompleted and InferenceTriggerCompleted events.
"""
if not self.healthy:
raise RuntimeError("Runner process was found to be dead")
task_params = task.task_params
assert isinstance(task_params, ChatCompletionTaskParams) # this is messy for now.
await self.write_queue.put(
ChatTaskMessage(
task_data=task_params,
),
)
# This is simpler for now: we say 'request started' as soon as we've told the runner to start, without waiting for an ack.
# If we need more reliability, the runner can have a new 'ready' message type.
if request_started_callback is not None:
await request_started_callback()
prefil_timeout = get_prefil_timeout(self.model_shard_meta)
token_timeout = get_token_generate_timeout(self.model_shard_meta)
timeout = prefil_timeout
self.logger.info(f'starting chat completion with timeout {timeout}')
while True:
try:
response = await self._read_with_error_check(timeout)
except asyncio.TimeoutError as e:
self.logger.info(f'timed out from timeout duration {timeout} - {"prefil" if timeout == prefil_timeout else "decoding stage"}')
raise e
match response:
case GenerationResponse():
yield TokenChunk(
command_id=CommandId(task.command_id),
idx=response.token,
model=self.model_shard_meta.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
)
timeout = token_timeout
case FinishedResponse():
break
case ErrorResponse():
await self.astop()
raise RunnerError(response.error_type, response.error_message, response.traceback)
case _:
raise ValueError(f'Unexpected response type found: {response}')
async def _write_coro(self):
while True:
message = await self.write_queue.get()
await supervisor_write_message(
self.runner_process,
message
)
async def _read_coro(self):
while True:
response: RunnerResponse | None = await supervisor_read_response(
self.runner_process
)
if response is None:
# Runner process died unexpectedly (C++ crash)
e = await self._raise_crashed()
if e:
raise e from EOFError
else:
break
match response:
case PrintResponse():
self.logger.info(f"runner printed: {response.text}")
case ErrorResponse():
## Failure case #1: a crash happens in Python, so it's neatly handled by passing an ErrorResponse with the details
await self.read_queue.put(response)
case _:
await self.read_queue.put(response)
async def astop(self) -> None:
# Cancel the stderr monitoring task
if not self.stderr_task.done():
self.stderr_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.stderr_task
async def await_task(task: asyncio.Task[Any]):
if not task.done():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
await await_task(self.stderr_task)
await await_task(self.read_task)
await await_task(self.write_task)
# Kill the process and all its children
await self._kill_process_tree()
await kill_process_tree(self.runner_process, self.logger)
# Wait to make sure that the model has been unloaded from memory
async def wait_for_memory_release() -> None:
@@ -160,89 +263,9 @@ class RunnerSupervisor:
await asyncio.sleep(0.1)
await wait_for_memory_release()
self.running = False
async def _kill_process_tree(self) -> None:
"""Kill the process and all its children forcefully."""
if self.runner_process.returncode is not None:
return # Process already dead
try:
# Get the main process
pid = self.runner_process.pid
# Find all child processes
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
# Kill all children first (bottom-up)
for child in reversed(children):
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
child.kill() # SIGKILL
# Kill the parent
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
parent.kill() # SIGKILL
except psutil.NoSuchProcess:
# Process already gone, try subprocess kill anyway
self.runner_process.kill()
# Wait for the subprocess to exit
try:
await asyncio.wait_for(self.runner_process.wait(), timeout=2.0)
except asyncio.TimeoutError:
self.logger.error(f"Process {pid} did not exit after kill signal")
except Exception as e:
self.logger.error(f"Error killing process tree: {e}")
async def _watch_runner(self) -> None:
returncode = await self.runner_process.wait()
self.running = False
if returncode != 0:
self.returncode = returncode # Will be picked up by _watch_stderr too
await self.astop()
async def _watch_stderr(self, logger: Logger, stderr_queue: asyncio.Queue[tuple[float, str]]) -> None:
assert self.runner_process.stderr is not None
while self.running:
try:
line_bytes = await self.runner_process.stderr.readline()
if not line_bytes:
break
line = line_bytes.decode('utf-8').strip()
await stderr_queue.put((time.time(), line))
logger.warning(f"Runner stderr read: {line}")
except Exception as e:
logger.warning(f"Error reading runner stderr: {e}")
break
async def _raise_crashed(self) -> Exception:
await self.astop()
# Accumulate all stderr messages from the queue
stderr_output = ''
while not self.stderr_queue.empty():
try:
timestamp, line = self.stderr_queue.get_nowait()
stderr_output += f"[{timestamp}] {line}\n"
except asyncio.QueueEmpty:
break
self.logger.error(f'Error {self.returncode}: {stderr_output}')
return RunnerError(
error_type="MLXCrash",
error_message=stderr_output,
traceback=traceback.format_exc(),
)
def __del__(self) -> None:
if self.running:
if self.runner_process.returncode is None:
print(
"Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process tree."
)
@@ -264,79 +287,49 @@ class RunnerSupervisor:
@property
def healthy(self) -> bool:
return (
self.running
and self.runner_process.returncode is None
self.runner_process.returncode is None
and self.runner_process.stdin is not None
and not self.runner_process.stdin.is_closing()
and self.runner_process.stdout is not None
)
    async def stream_response(
        self,
        task: Task,
        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
    ) -> AsyncGenerator[GenerationChunk]:
        """
        Streams a chat request from the model.
        The request is pushed to the runner, and if the shard is the terminal shard, the response is streamed back to the worker.
        request_started_callback is called once the request is pushed to the runner, used to publish InferencePrepareCompleted and InferenceTriggerCompleted events.
        """
        if not self.healthy:
            raise RuntimeError("Runner process was found to be dead")
        task_params = task.task_params
        assert isinstance(task_params, ChatCompletionTaskParams)  # this is messy for now.
        await supervisor_write_message(
            proc=self.runner_process,
            message=ChatTaskMessage(
                task_data=task_params,
            ),
        )
        # This is easy for now. If we need more reliability, the runner can have a new 'ready' message type.
        if request_started_callback is not None:
            await request_started_callback()
        prefil_timeout = get_prefil_timeout(task, self.model_shard_meta)
        token_timeout = get_token_generate_timeout(self.model_shard_meta)
        timeout = prefil_timeout
        self.logger.info(f'starting chat completion with timeout {timeout}')
        while True:
            try:
                line: RunnerResponse | None = await asyncio.wait_for(
                    supervisor_read_response(self.runner_process), timeout=timeout
                )
                if line is None:
                    continue
            except asyncio.TimeoutError as e:
                self.logger.info(f'timed out from timeout duration {timeout} - {"prefil" if timeout == prefil_timeout else "decoding stage"}')
                await self.astop()
                raise RunnerError(
                    error_type=type(e).__name__,
                    error_message=str(e),
                    traceback=traceback.format_exc(),
                ) from e
            # TODO: change this to a return None instead of an error coming from supervisor_read_response
            except EOFError as e:
                if not self.runner_process.returncode:
                    continue
                raise await self._raise_crashed() from e
            match line:
                case GenerationResponse():
                    yield TokenChunk(
                        command_id=CommandId(task.command_id),
                        idx=line.token,
                        model=self.model_shard_meta.model_meta.model_id,
                        text=line.text,
                        token_id=line.token,
                        finish_reason=line.finish_reason,
                    )
                    timeout = token_timeout
                case InitializedResponse():
                    raise ValueError('Initialized Response read during streaming flow')
                case FinishedResponse():
                    break
                case PrintResponse():
                    self.logger.info(f"runner printed: {line.text}")
                case ErrorResponse():
                    await self.astop()
                    raise RunnerError(line.error_type, line.error_message, line.traceback)
    ## Failure case #2: a crash happens in MLX / C++ (e.g. a segfault), so the error is flushed to stderr and the process dies
    async def _raise_crashed(self) -> Exception | None:
        if self.runner_process.returncode == 0:
            return None
        await self.astop()
        # Accumulate all stderr messages from the queue
        stderr_output = ''
        while not self.stderr_queue.empty():
            try:
                line = self.stderr_queue.get_nowait()
                stderr_output += f"{line}\n"
            except asyncio.QueueEmpty:
                break
        self.logger.error(f'Error {self.runner_process.returncode}: {stderr_output}')
        return RunnerError(
            error_type="MLXCrash",
            error_message=stderr_output,
            traceback=traceback.format_exc(),
        )
    async def _watch_stderr(self) -> None:
        assert self.runner_process.stderr is not None
        while True:
            try:
                line_bytes = await self.runner_process.stderr.readline()
                if not line_bytes:
                    break
                line = line_bytes.decode('utf-8').strip()
                await self.stderr_queue.put(line)
                self.logger.warning(f"Runner stderr read: {line}")
            except Exception as e:
                self.logger.warning(f"Error reading runner stderr: {e}")
                break
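
The pivotal new mechanism here is _read_with_error_check: responses now arrive on a queue filled by _read_coro, and each queue read is raced against the reader task itself, so a crash in the reader surfaces immediately instead of being misreported as a timeout. A minimal sketch of that race, with the supervisor specifics stripped out (names are illustrative):

import asyncio
from typing import Any

async def get_or_fail(queue: asyncio.Queue[Any], reader: asyncio.Task[None], timeout: float) -> Any:
    getter = asyncio.create_task(queue.get())
    done, _pending = await asyncio.wait(
        {getter, reader}, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
    )
    if getter in done:
        return getter.result()  # normal path: an item arrived in time
    getter.cancel()  # don't leak the pending queue read
    if reader in done:
        reader.result()  # re-raises whatever exception killed the reader
    raise asyncio.TimeoutError  # neither side finished: a genuine timeout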


@@ -1,10 +1,50 @@
import asyncio
import contextlib
import sys
from logging import Logger
from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS
from shared.types.tasks import Task
import psutil
from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS, LB_TFLOPS
from shared.types.worker.shards import ShardMetadata
async def kill_process_tree(runner_process: asyncio.subprocess.Process, logger: Logger) -> None:
"""Kill the process and all its children forcefully."""
if runner_process.returncode is not None:
return # Process already dead
try:
# Get the main process
pid = runner_process.pid
# Find all child processes
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
# Kill all children first (bottom-up)
for child in reversed(children):
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
child.kill() # SIGKILL
# Kill the parent
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
parent.kill() # SIGKILL
except psutil.NoSuchProcess:
# Process already gone, try subprocess kill anyway
runner_process.kill()
# Wait for the subprocess to exit
try:
await asyncio.wait_for(runner_process.wait(), timeout=2.0)
except asyncio.TimeoutError:
logger.error(f"Process {pid} did not exit after kill signal")
except Exception as e:
logger.error(f"Error killing process tree: {e}")
def get_runner_command() -> list[str]:
python = sys.executable
return [python, "-m", "worker.runner.runner"]
@@ -19,20 +59,13 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float:
return weights_size_kb / kbps_read + 2.0
def get_prefil_timeout(task: Task, model_shard_meta: ShardMetadata) -> float:
    def get_prompt_str(task: Task) -> str:
        messages = [x.content for x in task.task_params.messages if x.content]
        return ''.join(messages)
    # TODO: made this timeout very long
    tokens = len(get_prompt_str(task)) // 3 + 3000  # constant for now - the prompt is only tokenized on the device...
    # TODO: For now we just hack and assume we prefil at 10tok/s
    return tokens * 0.1
    # prompt_gflops = tokens * weights_size_gb * 2
    # return LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0
def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float:
    weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024)
    tokens = 1000  # constant for now - the prompt is only tokenized on the device...
    prompt_gflops = tokens * weights_size_gb * 2
    return LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0
def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float:
weights_size_kb = get_weights_size_kb(model_shard_meta)
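
As a sanity check on the restored prefill formula, a worked example with assumed inputs (the LB_TFLOPS value and shard size below are hypothetical, not taken from the repo):

LB_TFLOPS = 1.0        # assumed lower-bound compute, TFLOPS
weights_size_gb = 4.0  # assumed shard size
tokens = 1000
prompt_gflops = tokens * weights_size_gb * 2        # ~2 FLOPs per weight per token
timeout = LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0
print(f"{timeout:.2f}s")  # ~10.00s: with these inputs the +10.0 floor dominates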


@@ -35,7 +35,19 @@ def user_message():
@pytest.fixture
def logger() -> Logger:
return getLogger("test_logger")
import logging
logger = getLogger("test_logger")
logger.setLevel(logging.DEBUG)
# Add console handler if none exists
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
@pytest.fixture
async def model_meta() -> ModelMetadata:


@@ -74,7 +74,7 @@ async def test_execute_task_timeouts(
task=task
)
with pytest.raises(RunnerError): # At the moment this is a RunnerError that says 'TimeoutError'.
with pytest.raises(asyncio.TimeoutError):
await read_events_op(worker, execute_task_op)


@@ -164,7 +164,7 @@ async def test_stream_response_failed_once(
assert isinstance(event.chunk, TokenChunk)
response_string += event.chunk.text
assert 'elizabeth' in response_string.lower()
assert 'queen' in response_string.lower()
assert seen_task_started
assert seen_task_finished
@@ -206,7 +206,7 @@ async def test_stream_response_timeout(
print(events)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_message.lower()]) == 3
assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 3
await global_events.append_events(
[


@@ -53,7 +53,10 @@ async def test_runner_spinup_exception(
# Ensure the correct events have been emitted
events = await global_events.get_events_since(0)
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) \
and isinstance(x.event.runner_status, FailedRunnerStatus) \
and x.event.runner_status.error_message is not None \
and 'fake exception' in x.event.runner_status.error_message.lower()]) == 3
assert any([isinstance(x.event, InstanceDeleted) for x in events])


@@ -21,7 +21,8 @@ def user_message():
@pytest.mark.asyncio
async def test_supervisor_single_node_response(
@pytest.mark.skip(reason="Must run `sudo sysctl -w iogpu.wired_limit_mb=` and `sudo sysctl -w iogpu.wired_lwm_mb=` before running this test.")
async def test_supervisor_catches_oom(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
@@ -38,8 +39,11 @@ async def test_supervisor_single_node_response(
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = 'EXO RUNNER MUST OOM'
with pytest.raises(RunnerError):
async for _ in supervisor.stream_response(task):
pass
with pytest.raises(RunnerError) as exc_info:
async for _ in supervisor.stream_response(task):
pass
error = exc_info.value
assert 'memory' in error.error_message.lower()
await supervisor.astop()


@@ -72,20 +72,17 @@ async def test_supervisor_two_node_response(
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
instance_id = InstanceId()
create_supervisor_0 = asyncio.create_task(
RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 0),
async def create_supervisor(shard_idx: int) -> RunnerSupervisor:
supervisor = await RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, shard_idx),
hosts=hosts(2, offset=15),
logger=logger,
)
)
create_supervisor_1 = asyncio.create_task(
RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, 1),
hosts=hosts(2, offset=15),
logger=logger,
)
)
return supervisor
create_supervisor_0 = asyncio.create_task(create_supervisor(0))
create_supervisor_1 = asyncio.create_task(create_supervisor(1))
supervisor_0, supervisor_1 = await asyncio.gather(create_supervisor_0, create_supervisor_1)
await asyncio.sleep(0.1)


@@ -23,7 +23,7 @@ async def test_supervisor_instantiation_exception(
model_shard_meta.immediate_exception = True
with pytest.raises(RunnerError):
await RunnerSupervisor.create(
_ = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
@@ -40,7 +40,7 @@ async def test_supervisor_instantiation_timeout(
model_shard_meta.should_timeout = 10 # timeout after 10s
with pytest.raises(asyncio.TimeoutError):
await RunnerSupervisor.create(
_ = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
logger=logger,
@@ -88,7 +88,7 @@ async def test_supervisor_inference_timeout(
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'
with pytest.raises(RunnerError):
with pytest.raises(asyncio.TimeoutError):
async for _ in supervisor.stream_response(task):
pass


@@ -171,7 +171,7 @@ class Worker:
This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.
"""
assigned_runner = self._create_assigned_runner(op)
initial_progress = await asyncio.wait_for(self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata), timeout=15)
initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)
if initial_progress.status == "complete":
async for event in self._handle_already_downloaded_shard(assigned_runner):
@@ -217,8 +217,6 @@ class Worker:
runner = assigned_runner.runner
health_issues: list[str] = []
if not runner.running:
health_issues.append("runner.running is False")
if runner.runner_process.returncode is not None:
health_issues.append(f"runner_process.returncode is {runner.runner_process.returncode}")
if runner.runner_process.stdin is None:
@@ -348,6 +346,7 @@ class Worker:
## Operation Planner
async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
## It would be great if we could get rid of this async for ... yield pattern.
match op.op_type:
case RunnerOpType.ASSIGN_RUNNER:
event_generator = self._execute_assign_op(op)
@@ -410,4 +409,3 @@ class Worker:
assert self.worker_events is not None
await self.worker_events.append_events([event], self.node_id)
self.logger.info(f"published event: {event}")