Refactor runner supervisor
Co-authored-by: Gelu Vrabie <gelu@exolabs.net>
@@ -7,12 +7,14 @@ from typing import Any, Callable
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.generate import stream_generate  # type: ignore
 from mlx_lm.sample_utils import make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer  # type: ignore
 from mlx_lm.utils import load_model  # type: ignore
 from pydantic import RootModel

 from engines.mlx.auto_parallel import auto_parallel
+from shared.types.api import ChatCompletionMessage
 from shared.types.common import Host
 from shared.types.tasks import ChatCompletionTaskParams
 from shared.types.worker.shards import ShardMetadata
@@ -134,6 +136,46 @@ async def apply_chat_template(

     return prompt


+async def warmup_inference(
+    mlx_executor: concurrent.futures.ThreadPoolExecutor,
+    model: nn.Module,
+    tokenizer: TokenizerWrapper,
+    sampler: Callable[[mx.array], mx.array],
+) -> int:
+    loop = asyncio.get_running_loop()
+
+    warmup_prompt = await apply_chat_template(
+        mlx_executor=mlx_executor,
+        tokenizer=tokenizer,
+        chat_task_data=ChatCompletionTaskParams(
+            model="warmup",
+            messages=[
+                ChatCompletionMessage(
+                    role='user',
+                    content='Prompt to warm up the inference engine. Repeat this.'
+                )
+            ]
+        ),
+    )
+
+    tokens_generated = 0
+
+    def _generate_warmup():
+        nonlocal tokens_generated
+        for _ in stream_generate(
+            model=model,
+            tokenizer=tokenizer,
+            prompt=warmup_prompt,
+            max_tokens=50,
+            sampler=sampler,
+        ):
+            tokens_generated += 1
+
+    await loop.run_in_executor(mlx_executor, _generate_warmup)
+    mx_barrier()
+
+    return tokens_generated
+
+
 def mlx_force_oom(size: int = 40000) -> None:
     """
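The shape of `warmup_inference` above — blocking token generation pushed onto a dedicated `ThreadPoolExecutor` via `run_in_executor` so the asyncio loop keeps servicing stdin/stdout — is worth seeing in isolation. A minimal sketch, with a stand-in `generate()` in place of `stream_generate` (nothing below is exo code):

import asyncio
import concurrent.futures

def generate(prompt: str, max_tokens: int) -> int:
    # stand-in for the blocking stream_generate() loop
    return min(len(prompt.split()), max_tokens)

async def warmup() -> int:
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # run_in_executor wraps the blocking call in a future the loop can await
        return await loop.run_in_executor(pool, generate, "warm up the engine", 50)

print(asyncio.run(warmup()))  # -> 4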
@@ -23,8 +23,6 @@ async def supervisor_write_message(
     )

     encoded: bytes = message.model_dump_json().encode("utf-8") + b"\n"
-    print(f"message: {message}")
-    # print(f"encoded: {encoded}")
     proc.stdin.write(encoded)
     await proc.stdin.drain()
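For context on the framing here: supervisor and runner speak newline-delimited JSON over the subprocess pipes — one pydantic model per line, matched by `readline()` on the other end. A hedged sketch of the round trip (the `Message` type is illustrative; the real `RunnerMessage`/`RunnerResponse` unions live in `shared.types.worker.commands_runner`):

from pydantic import BaseModel, TypeAdapter

class Message(BaseModel):
    kind: str
    payload: str

adapter: TypeAdapter[Message] = TypeAdapter(Message)

# write side: one JSON document per line
encoded = Message(kind="setup", payload="hi").model_dump_json().encode("utf-8") + b"\n"
# read side: strip the newline and validate back into the model
decoded = adapter.validate_json(encoded.strip())
assert decoded == Message(kind="setup", payload="hi")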
@@ -63,12 +61,11 @@ async def supervisor_read_response(
         "proc.stdout should not be None when created with stdout=PIPE"
     )
     line_bytes: bytes = await asyncio.wait_for(proc.stdout.readline(), timeout=180)
     if not line_bytes:
-        # return None
         raise EOFError("No more data to read when reading response from runner")

     line: str = line_bytes.decode("utf-8").strip()
     if not line:
         return None

     try:
         return RunnerResponseTypeAdapter.validate_json(line)
     except Exception as err:
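One subtlety above: `readline()` on an asyncio stream returns `b""` only at EOF, while a blank line arrives as `b"\n"` and strips to `""` — which is why EOF raises `EOFError` but a blank line just returns `None`. A small self-contained demo (plain asyncio, not exo code):

import asyncio

async def demo() -> None:
    reader = asyncio.StreamReader()
    reader.feed_data(b'\n{"status": "ok"}\n')
    reader.feed_eof()
    assert (await reader.readline()) == b"\n"        # blank line -> skipped, return None
    assert (await reader.readline()).strip() == b'{"status": "ok"}'
    assert (await reader.readline()) == b""          # EOF -> raise EOFError upstream

asyncio.run(demo())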
@@ -98,4 +95,4 @@ def runner_write_error(error: Exception) -> None:
         error_message=str(error),
         traceback=traceback.format_exc(),
     )
     runner_write_response(error_response)
@@ -10,7 +10,12 @@ import mlx.nn as nn
 from mlx_lm.generate import stream_generate  # type: ignore
 from mlx_lm.tokenizer_utils import TokenizerWrapper

-from engines.mlx.utils_mlx import apply_chat_template, initialize_mlx, mlx_force_oom
+from engines.mlx.utils_mlx import (
+    apply_chat_template,
+    initialize_mlx,
+    mlx_force_oom,
+    warmup_inference,
+)
 from shared.openai_compat import FinishReason
 from shared.types.tasks import ChatCompletionTaskParams
 from shared.types.worker.commands_runner import (
@@ -122,6 +127,13 @@ async def main():
         partial(initialize_mlx, model_shard_meta=model_shard_meta, hosts=hosts),
     )

+    toks = await warmup_inference(
+        mlx_executor=mlx_executor,
+        model=model,
+        tokenizer=tokenizer,
+        sampler=sampler,
+    )
+    runner_print(f'Warmed up by generating {toks} tokens')
     runner_write_response(InitializedResponse(time_taken=time.time() - setup_start_time))

     while True:
@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import time
 import traceback
 from collections.abc import AsyncGenerator
 from logging import Logger
@@ -19,6 +18,7 @@ from shared.types.worker.commands_runner import (
     GenerationResponse,
     InitializedResponse,
     PrintResponse,
+    RunnerMessage,
     RunnerResponse,
     SetupMessage,
 )
@@ -34,37 +34,34 @@ from worker.runner.utils import (
     get_runner_command,
     get_token_generate_timeout,
     get_weights_size_kb,
+    kill_process_tree,
 )


 class RunnerSupervisor:
     """
     RunnerSupervisor manages the lifecycle of a runner subprocess for model inference.
     Use the class method `create` to properly initialize an instance.
     """
     # TODO: Logger.

     def __init__(
         self,
         model_shard_meta: ShardMetadata,
         hosts: list[Host],
         runner_process: asyncio.subprocess.Process,
         logger: Logger,
+        read_queue: asyncio.Queue[RunnerResponse],
+        write_queue: asyncio.Queue[RunnerMessage],
+        stderr_queue: asyncio.Queue[str],
     ):
         """Private constructor. Use RunnerSupervisor.create() instead."""
         self.model_shard_meta: ShardMetadata = model_shard_meta
         self.hosts: list[Host] = hosts
         self.runner_process: asyncio.subprocess.Process = runner_process
         self.running: bool = True

-        self.stderr_queue = asyncio.Queue[tuple[float, str]]()
-        self.stderr_task = asyncio.create_task(self._watch_stderr(logger, self.stderr_queue))
         self.running_task: asyncio.Task[None] = asyncio.create_task(
             self._watch_runner()
         )
         self.logger = logger
         self.returncode: int | None = None
-        self.stderr_outpu: str | None = None
-
-        self.model_shard_meta = model_shard_meta
-        self.hosts = hosts
-        self.runner_process = runner_process

+        self.read_queue = read_queue
+        self.write_queue = write_queue
+        self.stderr_queue = stderr_queue
+
+        self.read_task = asyncio.create_task(self._read_coro())
+        self.write_task = asyncio.create_task(self._write_coro())
+        self.stderr_task = asyncio.create_task(self._watch_stderr())

     @classmethod
     async def create(
@@ -79,8 +76,7 @@ class RunnerSupervisor:
         The .create() classmethod pattern is used to ensure the constructor is asynchronous.
         """
         cmd: list[str] = get_runner_command()

-        runner_process: asyncio.subprocess.Process = (
+        runner_process = (
             await asyncio.create_subprocess_exec(
                 *cmd,
                 stdin=asyncio.subprocess.PIPE,
@@ -88,63 +84,170 @@ class RunnerSupervisor:
                 stderr=asyncio.subprocess.PIPE,
             )
         )
-        logger.info(f'initializing mlx instance with {model_shard_meta=}')

+        read_queue: asyncio.Queue[RunnerResponse] = asyncio.Queue()
+        write_queue: asyncio.Queue[RunnerMessage] = asyncio.Queue()
+        stderr_queue: asyncio.Queue[str] = asyncio.Queue()
+
+        self = cls(
+            model_shard_meta=model_shard_meta,
+            hosts=hosts,
+            runner_process=runner_process,
+            logger=logger,
+            read_queue=read_queue,
+            write_queue=write_queue,
+            stderr_queue=stderr_queue,
+        )

-        await supervisor_write_message(
-            runner_process,
-            SetupMessage(
-                model_shard_meta=model_shard_meta,
-                hosts=hosts,
-            ),
-        )
-
-        async def read_initialization_message() -> None:
-            while True:
-                try:
-                    line: RunnerResponse | None = await supervisor_read_response(
-                        self.runner_process
-                    )
-                    if line is None:
-                        continue
-                except EOFError:
-                    if not self.runner_process.returncode:
-                        continue
-                    raise await self._raise_crashed() from EOFError
-
-                if isinstance(line, PrintResponse):
-                    self.logger.info(f"runner printed: {line.text}")
-                    continue
-                elif isinstance(line, ErrorResponse):
-                    raise RunnerError(line.error_type, line.error_message, line.traceback or "")
-                elif isinstance(line, InitializedResponse):
-                    assert isinstance(line, InitializedResponse)
-                    logger.info(f'Runner initialized in {line.time_taken} seconds')
-                    break
-                else:
-                    raise AssertionError(f'Non-valid line read from runner during initialization: {line}')
+        self.logger.info(f'initializing mlx instance with {model_shard_meta=}')
+        await self.write_queue.put(SetupMessage(
+            model_shard_meta=model_shard_meta,
+            hosts=hosts,
+        ))

         if not initialize_timeout:
             initialize_timeout = get_init_timeout(model_shard_meta)
-        await asyncio.wait_for(read_initialization_message(), timeout=initialize_timeout)

+        response = await self._read_with_error_check(initialize_timeout)
+
+        assert isinstance(response, InitializedResponse)
+        self.logger.info(f'Runner initialized in {response.time_taken} seconds')

         return self

+    async def _read_with_error_check(self, timeout: float) -> RunnerResponse:
+        """
+        Read from the queue with a timeout, but also check if the read_task has failed.
+        """
+        queue_task = asyncio.create_task(self.read_queue.get())
+
+        done, pending = await asyncio.wait(
+            [queue_task, self.read_task],
+            timeout=timeout,
+            return_when=asyncio.FIRST_COMPLETED
+        )
+
+        for task in pending:
+            if task is queue_task:
+                task.cancel()
+
+        if queue_task in done:
+            response = await queue_task
+            if isinstance(response, ErrorResponse):
+                raise RunnerError(response.error_type, response.error_message, response.traceback or "")
+            return response
+
+        if self.read_task in done:
+            await self.read_task  # Re-raises any exception from read_task
+            self.logger.error('Unreachable code reached: read_task completed without raising.')
+
+        # if we haven't read from the queue, we have timed out.
+        await self.astop()
+        raise asyncio.TimeoutError()
+
+    async def stream_response(
+        self,
+        task: Task,
+        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
+    ) -> AsyncGenerator[GenerationChunk]:
+        """
+        Streams a chat request from the model.
+        The request is pushed to the runner, and if the shard is the terminal shard, the response is streamed back to the worker.
+        request_started_callback is called once the request is pushed to the runner; it is used to publish InferencePrepareCompleted and InferenceTriggerCompleted events.
+        """
+        if not self.healthy:
+            raise RuntimeError("Runner process was found to be dead")
+
+        task_params = task.task_params
+        assert isinstance(task_params, ChatCompletionTaskParams)  # this is messy for now.
+        await self.write_queue.put(
+            ChatTaskMessage(
+                task_data=task_params,
+            ),
+        )
+
+        # This is simpler for now: we say 'request started' as soon as we've told the runner to start, without waiting for an ack.
+        # If we need more reliability, the runner can have a new 'ready' message type.
+        if request_started_callback is not None:
+            await request_started_callback()
+
+        prefil_timeout = get_prefil_timeout(self.model_shard_meta)
+        token_timeout = get_token_generate_timeout(self.model_shard_meta)
+        timeout = prefil_timeout
+        self.logger.info(f'starting chat completion with timeout {timeout}')
+
+        while True:
+            try:
+                response = await self._read_with_error_check(timeout)
+            except asyncio.TimeoutError as e:
+                self.logger.info(f'timed out after {timeout}s - {"prefil" if timeout == prefil_timeout else "decoding stage"}')
+                raise e
+
+            match response:
+                case GenerationResponse():
+                    yield TokenChunk(
+                        command_id=CommandId(task.command_id),
+                        idx=response.token,
+                        model=self.model_shard_meta.model_meta.model_id,
+                        text=response.text,
+                        token_id=response.token,
+                        finish_reason=response.finish_reason,
+                    )
+                    timeout = token_timeout
+                case FinishedResponse():
+                    break
+                case ErrorResponse():
+                    await self.astop()
+                    raise RunnerError(response.error_type, response.error_message, response.traceback)
+                case _:
+                    raise ValueError(f'Unexpected response type found: {response}')
+
+    async def _write_coro(self):
+        while True:
+            message = await self.write_queue.get()
+            await supervisor_write_message(
+                self.runner_process,
+                message
+            )
+
+    async def _read_coro(self):
+        while True:
+            response: RunnerResponse | None = await supervisor_read_response(
+                self.runner_process
+            )
+            if response is None:
+                # Runner process died unexpectedly (C++ crash)
+                e = await self._raise_crashed()
+                if e:
+                    raise e from EOFError
+                else:
+                    break
+
+            match response:
+                case PrintResponse():
+                    self.logger.info(f"runner printed: {response.text}")
+                case ErrorResponse():
+                    ## Failure case #1: a crash happens in Python, so it's neatly handled by passing an ErrorResponse with the details
+                    await self.read_queue.put(response)
+                case _:
+                    await self.read_queue.put(response)

     async def astop(self) -> None:
-        # Cancel the stderr monitoring task
-        if not self.stderr_task.done():
-            self.stderr_task.cancel()
-            with contextlib.suppress(asyncio.CancelledError):
-                await self.stderr_task
+        async def await_task(task: asyncio.Task[Any]):
+            if not task.done():
+                task.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await task
+
+        await await_task(self.stderr_task)
+        await await_task(self.read_task)
+        await await_task(self.write_task)

         # Kill the process and all its children
-        await self._kill_process_tree()
+        await kill_process_tree(self.runner_process, self.logger)

         # Wait to make sure that the model has been unloaded from memory
         async def wait_for_memory_release() -> None:
@@ -160,89 +263,9 @@ class RunnerSupervisor:
                 await asyncio.sleep(0.1)

         await wait_for_memory_release()
         self.running = False

-    async def _kill_process_tree(self) -> None:
-        """Kill the process and all its children forcefully."""
-        if self.runner_process.returncode is not None:
-            return  # Process already dead
-
-        try:
-            # Get the main process
-            pid = self.runner_process.pid
-
-            # Find all child processes
-            try:
-                parent = psutil.Process(pid)
-                children = parent.children(recursive=True)
-
-                # Kill all children first (bottom-up)
-                for child in reversed(children):
-                    with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
-                        child.kill()  # SIGKILL
-
-                # Kill the parent
-                with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
-                    parent.kill()  # SIGKILL
-
-            except psutil.NoSuchProcess:
-                # Process already gone, try subprocess kill anyway
-                self.runner_process.kill()
-
-            # Wait for the subprocess to exit
-            try:
-                await asyncio.wait_for(self.runner_process.wait(), timeout=2.0)
-            except asyncio.TimeoutError:
-                self.logger.error(f"Process {pid} did not exit after kill signal")
-
-        except Exception as e:
-            self.logger.error(f"Error killing process tree: {e}")
-
-    async def _watch_runner(self) -> None:
-        returncode = await self.runner_process.wait()
-        self.running = False
-
-        if returncode != 0:
-            self.returncode = returncode  # Will be picked up by _watch_stderr too
-
-        await self.astop()
-
-    async def _watch_stderr(self, logger: Logger, stderr_queue: asyncio.Queue[tuple[float, str]]) -> None:
-        assert self.runner_process.stderr is not None
-        while self.running:
-            try:
-                line_bytes = await self.runner_process.stderr.readline()
-                if not line_bytes:
-                    break
-                line = line_bytes.decode('utf-8').strip()
-
-                await stderr_queue.put((time.time(), line))
-                logger.warning(f"Runner stderr read: {line}")
-            except Exception as e:
-                logger.warning(f"Error reading runner stderr: {e}")
-                break
-
-    async def _raise_crashed(self) -> Exception:
-        await self.astop()
-
-        # Accumulate all stderr messages from the queue
-        stderr_output = ''
-        while not self.stderr_queue.empty():
-            try:
-                timestamp, line = self.stderr_queue.get_nowait()
-                stderr_output += f"[{timestamp}] {line}\n"
-            except asyncio.QueueEmpty:
-                break
-
-        self.logger.error(f'Error {self.returncode}: {stderr_output}')
-        return RunnerError(
-            error_type="MLXCrash",
-            error_message=stderr_output,
-            traceback=traceback.format_exc(),
-        )

     def __del__(self) -> None:
-        if self.running:
+        if self.runner_process.returncode is None:
             print(
                 "Warning: RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process tree."
             )
@@ -264,79 +287,49 @@ class RunnerSupervisor:
     @property
     def healthy(self) -> bool:
         return (
-            self.running
-            and self.runner_process.returncode is None
+            self.runner_process.returncode is None
             and self.runner_process.stdin is not None
             and not self.runner_process.stdin.is_closing()
             and self.runner_process.stdout is not None
         )

-    async def stream_response(
-        self,
-        task: Task,
-        request_started_callback: Callable[..., CoroutineType[Any, Any, None]] | None = None,
-    ) -> AsyncGenerator[GenerationChunk]:
-        """
-        Streams a chat request from the model.
-        The request is pushed to the runner, and if the shard is the terminal shard, the response is streamed back to the worker.
-        request_started_callback is called once the request is pushed to the runner, used to publish InferencePrepareCompleted and InferenceTriggerCompleted events.
-        """
-        if not self.healthy:
-            raise RuntimeError("Runner process was found to be dead")
-        task_params = task.task_params
-        assert isinstance(task_params, ChatCompletionTaskParams)  # this is messy for now.
-        await supervisor_write_message(
-            proc=self.runner_process,
-            message=ChatTaskMessage(
-                task_data=task_params,
-            ),
-        )
-        # This is easy for now. If we need more reliability, the runner can have a new 'ready' message type.
-        if request_started_callback is not None:
-            await request_started_callback()
-        prefil_timeout = get_prefil_timeout(task, self.model_shard_meta)
-        token_timeout = get_token_generate_timeout(self.model_shard_meta)
-        timeout = prefil_timeout
-        self.logger.info(f'starting chat completion with timeout {timeout}')
-        while True:
-            try:
-                line: RunnerResponse | None = await asyncio.wait_for(supervisor_read_response(
-                    self.runner_process
-                ), timeout=timeout)
-                if line is None:
-                    continue
-            except asyncio.TimeoutError as e:
-                self.logger.info(f'timed out from timeout duration {timeout} - {"prefil" if timeout == prefil_timeout else "decoding stage"}')
-                await self.astop()
-                raise RunnerError(
-                    error_type=type(e).__name__,
-                    error_message=str(e),
-                    traceback=traceback.format_exc(),
-                ) from e
-            # TODO: change this to a return None instead of an error coming from supervisor_read_response
-            except EOFError as e:
-                if not self.runner_process.returncode:
-                    continue
-                raise await self._raise_crashed() from e
-            match line:
-                case GenerationResponse():
-                    yield TokenChunk(
-                        command_id=CommandId(task.command_id),
-                        idx=line.token,
-                        model=self.model_shard_meta.model_meta.model_id,
-                        text=line.text,
-                        token_id=line.token,
-                        finish_reason=line.finish_reason,
-                    )
-                    timeout = token_timeout
-                case InitializedResponse():
-                    raise ValueError('Initialized Response read during streaming flow')
-                case FinishedResponse():
-                    break
-                case PrintResponse():
-                    # print(f"runner printed: {line.text}")
-                    self.logger.info(f"runner printed: {line.text}")
-                case ErrorResponse():
-                    await self.astop()
-                    raise RunnerError(line.error_type, line.error_message, line.traceback)

+    ## Failure case #2: a crash happens in MLX / C++ (eg segfault), so the error is flushed to stderr and the process dies
+    async def _raise_crashed(self) -> Exception | None:
+        if self.runner_process.returncode == 0:
+            return None
+
+        await self.astop()
+
+        # Accumulate all stderr messages from the queue
+        stderr_output = ''
+        while not self.stderr_queue.empty():
+            try:
+                line = self.stderr_queue.get_nowait()
+                stderr_output += f"{line}\n"
+            except asyncio.QueueEmpty:
+                break
+
+        self.logger.error(f'Error {self.runner_process.returncode}: {stderr_output}')
+        return RunnerError(
+            error_type="MLXCrash",
+            error_message=stderr_output,
+            traceback=traceback.format_exc(),
+        )
+
+    async def _watch_stderr(self) -> None:
+        assert self.runner_process.stderr is not None
+        while True:
+            try:
+                line_bytes = await self.runner_process.stderr.readline()
+                if not line_bytes:
+                    break
+                line = line_bytes.decode('utf-8').strip()
+
+                await self.stderr_queue.put(line)
+                self.logger.warning(f"Runner stderr read: {line}")
+            except Exception as e:
+                self.logger.warning(f"Error reading runner stderr: {e}")
+                break
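Two patterns in the refactored supervisor deserve a standalone look. First, the `.create()` classmethod pattern named in the docstring: `__init__` stays synchronous while awaitable setup lives in a classmethod. Second, `_read_with_error_check`'s use of `asyncio.wait(..., return_when=FIRST_COMPLETED)` to race a queue read against a long-lived background task, so a crash in the reader surfaces immediately instead of hiding behind the queue timeout. Both sketches below are minimal and use illustrative names, not exo's:

import asyncio

class Service:
    def __init__(self, conn: str):
        self.conn = conn  # synchronous wiring only; no awaits allowed here

    @classmethod
    async def create(cls, host: str) -> "Service":
        await asyncio.sleep(0)  # stand-in for awaitable setup (spawn, handshake)
        return cls(conn=f"connected:{host}")

async def read_or_fail(queue: asyncio.Queue, reader: asyncio.Task, timeout: float):
    """Race a queue.get() against the reader task, like _read_with_error_check."""
    getter = asyncio.create_task(queue.get())
    done, _pending = await asyncio.wait(
        {getter, reader}, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
    )
    if getter in done:
        return getter.result()
    getter.cancel()                 # don't leak the pending get()
    if reader in done:
        reader.result()             # re-raises the reader's exception
    raise asyncio.TimeoutError()

async def main() -> None:
    service = await Service.create("localhost")
    queue: asyncio.Queue = asyncio.Queue()
    reader = asyncio.create_task(asyncio.sleep(3600))  # stand-in read loop
    await queue.put(service.conn)
    print(await read_or_fail(queue, reader, timeout=1.0))  # -> connected:localhost
    reader.cancel()

asyncio.run(main())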
@@ -1,10 +1,50 @@
 import asyncio
 import contextlib
 import sys
 from logging import Logger

-from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS
-from shared.types.tasks import Task
+import psutil
+
+from shared.constants import LB_DISK_GBPS, LB_MEMBW_GBPS, LB_TFLOPS
 from shared.types.worker.shards import ShardMetadata


+async def kill_process_tree(runner_process: asyncio.subprocess.Process, logger: Logger) -> None:
+    """Kill the process and all its children forcefully."""
+    if runner_process.returncode is not None:
+        return  # Process already dead
+
+    try:
+        # Get the main process
+        pid = runner_process.pid
+
+        # Find all child processes
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)
+
+            # Kill all children first (bottom-up)
+            for child in reversed(children):
+                with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
+                    child.kill()  # SIGKILL
+
+            # Kill the parent
+            with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
+                parent.kill()  # SIGKILL
+
+        except psutil.NoSuchProcess:
+            # Process already gone, try subprocess kill anyway
+            runner_process.kill()
+
+        # Wait for the subprocess to exit
+        try:
+            await asyncio.wait_for(runner_process.wait(), timeout=2.0)
+        except asyncio.TimeoutError:
+            logger.error(f"Process {pid} did not exit after kill signal")
+
+    except Exception as e:
+        logger.error(f"Error killing process tree: {e}")
+
+
 def get_runner_command() -> list[str]:
     python = sys.executable
     return [python, "-m", "worker.runner.runner"]

@@ -19,20 +59,13 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float:

     return weights_size_kb / kbps_read + 2.0

-def get_prefil_timeout(task: Task, model_shard_meta: ShardMetadata) -> float:
-    def get_prompt_str(task: Task) -> str:
-        messages = [x.content for x in task.task_params.messages if x.content]
-        return ''.join(messages)
-
-    # TODO: made this timeout very long
-    tokens = len(get_prompt_str(task)) // 3 + 3000  # constant for now - the prompt is only tokenized on the device...
-
-    # TODO: For now we just hack and assume we prefil at 10tok/s
-    return tokens * 0.1
-
-    # prompt_gflops = tokens * weights_size_gb * 2
-    # return LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0
+def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float:
+    weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024)
+
+    tokens = 1000  # constant for now - the prompt is only tokenized on the device...
+    prompt_gflops = tokens * weights_size_gb * 2
+
+    return LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0

 def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float:
     weights_size_kb = get_weights_size_kb(model_shard_meta)
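For a feel of the new `get_prefil_timeout` shape: `prompt_gflops` grows with the shard, so the `LB_TFLOPS / (1024 * prompt_gflops)` term shrinks as the shard gets larger, and the constant `+ 10.0` floor dominates for realistic sizes. A back-of-envelope check with a made-up constant (`LB_TFLOPS` here is illustrative, not the value in `shared.constants`):

LB_TFLOPS = 10.0                     # assumed lower-bound TFLOPS, for illustration only
weights_size_gb = 4.0                # e.g. a 4 GB shard
tokens = 1000                        # the fixed prompt-size constant from above
prompt_gflops = tokens * weights_size_gb * 2   # 8000
timeout = LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0
print(timeout)                       # ~10.0 -> the +10.0 floor dominates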
@@ -35,7 +35,19 @@ def user_message():

 @pytest.fixture
 def logger() -> Logger:
-    return getLogger("test_logger")
+    import logging
+    logger = getLogger("test_logger")
+    logger.setLevel(logging.DEBUG)
+
+    # Add console handler if none exists
+    if not logger.handlers:
+        handler = logging.StreamHandler()
+        handler.setLevel(logging.DEBUG)
+        formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    return logger

 @pytest.fixture
 async def model_meta() -> ModelMetadata:
@@ -74,7 +74,7 @@ async def test_execute_task_timeouts(
         task=task
     )

-    with pytest.raises(RunnerError):  # At the moment this is a RunnerError that says 'TimeoutError'.
+    with pytest.raises(asyncio.TimeoutError):
        await read_events_op(worker, execute_task_op)
@@ -164,7 +164,7 @@ async def test_stream_response_failed_once(
         assert isinstance(event.chunk, TokenChunk)
         response_string += event.chunk.text

-    assert 'elizabeth' in response_string.lower()
+    assert 'queen' in response_string.lower()
     assert seen_task_started
     assert seen_task_finished
@@ -206,7 +206,7 @@ async def test_stream_response_timeout(
     print(events)
     assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
     assert len([x for x in events if isinstance(x.event, TaskStateUpdated) and x.event.task_status == TaskStatus.FAILED]) == 3
-    assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_message.lower()]) == 3
+    assert len([x for x in events if isinstance(x.event, TaskFailed) and 'timeouterror' in x.event.error_type.lower()]) == 3

     await global_events.append_events(
         [
@@ -53,7 +53,10 @@ async def test_runner_spinup_exception(
     # Ensure the correct events have been emitted
     events = await global_events.get_events_since(0)

-    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) and isinstance(x.event.runner_status, FailedRunnerStatus)]) == 3
+    assert len([x for x in events if isinstance(x.event, RunnerStatusUpdated) \
+        and isinstance(x.event.runner_status, FailedRunnerStatus) \
+        and x.event.runner_status.error_message is not None \
+        and 'fake exception' in x.event.runner_status.error_message.lower()]) == 3
     assert any([isinstance(x.event, InstanceDeleted) for x in events])
@@ -21,7 +21,8 @@ def user_message():

 @pytest.mark.asyncio
-async def test_supervisor_single_node_response(
+@pytest.mark.skip(reason="Must run `sudo sysctl -w iogpu.wired_limit_mb=` and `sudo sysctl -w iogpu.wired_lwm_mb=` before running this test.")
+async def test_supervisor_catches_oom(
     pipeline_shard_meta: Callable[..., PipelineShardMetadata],
     hosts: Callable[..., list[Host]],
     chat_completion_task: Callable[[InstanceId, TaskId], Task],

@@ -38,8 +39,11 @@ async def test_supervisor_single_node_response(

     task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
     task.task_params.messages[0].content = 'EXO RUNNER MUST OOM'
-    with pytest.raises(RunnerError):
-        async for _ in supervisor.stream_response(task):
-            pass
+    with pytest.raises(RunnerError) as exc_info:
+        async for _ in supervisor.stream_response(task):
+            pass
+
+    error = exc_info.value
+    assert 'memory' in error.error_message.lower()

     await supervisor.astop()
@@ -72,20 +72,17 @@ async def test_supervisor_two_node_response(
 ):
     """Test that asking for the capital of France returns 'Paris' in the response"""
-    instance_id = InstanceId()
-    create_supervisor_0 = asyncio.create_task(
-        RunnerSupervisor.create(
-            model_shard_meta=pipeline_shard_meta(2, 0),
+    async def create_supervisor(shard_idx: int) -> RunnerSupervisor:
+        supervisor = await RunnerSupervisor.create(
+            model_shard_meta=pipeline_shard_meta(2, shard_idx),
             hosts=hosts(2, offset=15),
             logger=logger,
         )
-    )
-    create_supervisor_1 = asyncio.create_task(
-        RunnerSupervisor.create(
-            model_shard_meta=pipeline_shard_meta(2, 1),
-            hosts=hosts(2, offset=15),
-            logger=logger,
-        )
-    )
+        return supervisor
+
+    create_supervisor_0 = asyncio.create_task(create_supervisor(0))
+    create_supervisor_1 = asyncio.create_task(create_supervisor(1))
     supervisor_0, supervisor_1 = await asyncio.gather(create_supervisor_0, create_supervisor_1)

     await asyncio.sleep(0.1)
@@ -23,7 +23,7 @@ async def test_supervisor_instantiation_exception(
     model_shard_meta.immediate_exception = True

     with pytest.raises(RunnerError):
-        await RunnerSupervisor.create(
+        _ = await RunnerSupervisor.create(
             model_shard_meta=model_shard_meta,
             hosts=hosts(1, offset=10),
             logger=logger,

@@ -40,7 +40,7 @@ async def test_supervisor_instantiation_timeout(
     model_shard_meta.should_timeout = 10  # timeout after 10s

     with pytest.raises(asyncio.TimeoutError):
-        await RunnerSupervisor.create(
+        _ = await RunnerSupervisor.create(
             model_shard_meta=model_shard_meta,
             hosts=hosts(1, offset=10),
             logger=logger,

@@ -88,7 +88,7 @@ async def test_supervisor_inference_timeout(

     task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
     task.task_params.messages[0].content = 'EXO RUNNER MUST TIMEOUT'
-    with pytest.raises(RunnerError):
+    with pytest.raises(asyncio.TimeoutError):
         async for _ in supervisor.stream_response(task):
             pass
@@ -171,7 +171,7 @@ class Worker:
         This op assigns the runner, and moves from Downloading -> Inactive (ready to spin) state.
         """
         assigned_runner = self._create_assigned_runner(op)
-        initial_progress = await asyncio.wait_for(self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata), timeout=15)
+        initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(op.shard_metadata)

         if initial_progress.status == "complete":
             async for event in self._handle_already_downloaded_shard(assigned_runner):

@@ -217,8 +217,6 @@ class Worker:
         runner = assigned_runner.runner
         health_issues: list[str] = []

-        if not runner.running:
-            health_issues.append("runner.running is False")
         if runner.runner_process.returncode is not None:
             health_issues.append(f"runner_process.returncode is {runner.runner_process.returncode}")
         if runner.runner_process.stdin is None:

@@ -348,6 +346,7 @@ class Worker:
     ## Operation Planner

     async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
+        ## It would be great if we can get rid of this async for ... yield pattern.
         match op.op_type:
             case RunnerOpType.ASSIGN_RUNNER:
                 event_generator = self._execute_assign_op(op)

@@ -410,4 +409,3 @@ class Worker:
         assert self.worker_events is not None
         await self.worker_events.append_events([event], self.node_id)
-        self.logger.info(f"published event: {event}")