diff --git a/.gitignore b/.gitignore index 936e5433..310df30d 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,7 @@ dist/ */.DS_Store # Says this symlink should be git-ignored https://github.com/juspay/just-flake -just-flake.just \ No newline at end of file +just-flake.just + +# for the gitingest enthusiasts +digest.txt \ No newline at end of file diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py index e8df5a8d..72b99584 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/engines/mlx/utils_mlx.py @@ -136,7 +136,7 @@ def initialize_mlx( def shard_and_load( - model_shard_meta: ShardMetadata, + model_shard_meta: ShardMetadata, ) -> tuple[nn.Module, TokenizerWrapper]: model_path = build_model_path(model_shard_meta.model_meta.model_id) diff --git a/src/exo/shared/global_conn.py b/src/exo/shared/global_conn.py new file mode 100644 index 00000000..5def2999 --- /dev/null +++ b/src/exo/shared/global_conn.py @@ -0,0 +1,64 @@ +# src/exo/shared/global_conn.py + +import asyncio +import threading +from multiprocessing.connection import Connection +from typing import Optional + +from exo.shared.types.worker.commands_runner import ( + RunnerMessage, + RunnerResponse, +) + + +class AsyncConnection[SendT, RecvT]: + """ + Async/sync wrapper around multiprocessing.Connection with thread-safe send. + Use: + - await send(...) from asyncio code + - send_sync(...) from executor/background threads + """ + def __init__(self, conn: Connection): + self._conn = conn + self._send_lock = threading.Lock() + self._recv_lock = threading.Lock() + + # ---- sending ---- + async def send(self, obj: SendT) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._send_blocking, obj) + + def send_sync(self, obj: SendT) -> None: + self._send_blocking(obj) + + def _send_blocking(self, obj: SendT) -> None: + # Single critical section for the whole pickle frame + with self._send_lock: + self._conn.send(obj) + + # ---- receiving ---- + async def recv(self) -> RecvT: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._recv_blocking) + + def _recv_blocking(self) -> RecvT: + # Not strictly needed in your parent, but safe if misused elsewhere + with self._recv_lock: + return self._conn.recv() # type: ignore[no-any-return] + + async def poll(self, timeout: float | None = None) -> bool: + return await asyncio.to_thread(self._conn.poll, timeout) + + def close(self) -> None: + self._conn.close() + +_conn: Optional[AsyncConnection[RunnerResponse, RunnerMessage]] = None + +def set_conn(c: AsyncConnection[RunnerResponse, RunnerMessage]) -> None: + global _conn + _conn = c + +def get_conn() -> AsyncConnection[RunnerResponse, RunnerMessage]: + if _conn is None: + raise RuntimeError("Global conn has not been set yet") + return _conn diff --git a/src/exo/shared/types/worker/communication.py b/src/exo/shared/types/worker/communication.py index a1ea6c4e..3afe8e69 100644 --- a/src/exo/shared/types/worker/communication.py +++ b/src/exo/shared/types/worker/communication.py @@ -1,138 +1,17 @@ import asyncio -import json -import struct -import sys import traceback -from typing import Any, BinaryIO, Dict, Tuple, Union, cast from loguru import logger +from exo.shared.global_conn import AsyncConnection, get_conn from exo.shared.types.worker.commands_runner import ( ErrorResponse, PrintResponse, RunnerMessage, - RunnerMessageTypeAdapter, RunnerResponse, RunnerResponseType, - RunnerResponseTypeAdapter, ) -### Utils - SAFE LENGTH READ/WRITE - -MAGIC = b"EXO1" 
-HDR_FMT = "!I" # 4-byte big-endian length - - -async def write_frame(stream: Union[asyncio.StreamWriter, Any], obj: Union[Dict[str, Any], bytes]) -> None: - """Write a length-prefixed frame to a stream.""" - payload = obj if isinstance(obj, bytes) else json.dumps(obj).encode("utf-8") - header = MAGIC + struct.pack(HDR_FMT, len(payload)) - stream.write(header + payload) - if hasattr(stream, 'drain'): - await stream.drain() - - -async def read_frame(stream: Union[asyncio.StreamReader, Any]) -> Dict[str, Any]: - """Read a length-prefixed frame from a stream.""" - # Read 8 bytes: 4-byte magic + 4-byte length - header: bytes = await stream.readexactly(8) - if header[:4] != MAGIC: - # Fallback to legacy newline mode for backward compatibility - # Reconstruct the partial line and read the rest - remaining: bytes = await stream.readline() - line = header + remaining - return cast(Dict[str, Any], json.loads(line.strip().decode('utf-8'))) - - (length,) = cast(Tuple[int], struct.unpack(HDR_FMT, header[4:])) - data: bytes = await stream.readexactly(length) - return cast(Dict[str, Any], json.loads(data.decode('utf-8'))) - - -def write_frame_sync(stream: BinaryIO, obj: Union[Dict[str, Any], bytes]) -> None: - """Synchronous version of write_frame for use in runner.""" - payload = obj if isinstance(obj, bytes) else json.dumps(obj).encode("utf-8") - header = MAGIC + struct.pack(HDR_FMT, len(payload)) - stream.write(header + payload) - stream.flush() - - -def read_frame_sync(stream: BinaryIO) -> Dict[str, Any]: - """Synchronous version of read_frame for use in runner.""" - # Read 8 bytes: 4-byte magic + 4-byte length - header: bytes = stream.read(8) - if not header or len(header) < 8: - raise EOFError("No more data to read") - - if header[:4] != MAGIC: - # Fallback to legacy newline mode for backward compatibility - # Reconstruct the partial line and read the rest - remaining: bytes = stream.readline() - if not remaining: - raise EOFError("No more data to read") - line = header + remaining - return cast(Dict[str, Any], json.loads(line.strip().decode('utf-8'))) - - (length,) = cast(Tuple[int], struct.unpack(HDR_FMT, header[4:])) - data: bytes = stream.read(length) - if len(data) < length: - raise EOFError(f"Expected {length} bytes, got {len(data)}") - return cast(Dict[str, Any], json.loads(data.decode('utf-8'))) - - - -### Utils - MESSAGE TO RUNNER - -async def supervisor_write_message( - proc: asyncio.subprocess.Process, message: RunnerMessage -) -> None: - assert proc.stdin is not None, ( - "proc.stdin should not be None when created with stdin=PIPE" - ) - - # Use model_dump_json to get proper JSON encoding for Pydantic types like IPv4Address - await write_frame(proc.stdin, message.model_dump_json().encode('utf-8')) - - -async def runner_read_message() -> RunnerMessage: - loop = asyncio.get_running_loop() - - # Use executor to avoid blocking the event loop - data: Dict[str, Any] = await loop.run_in_executor(None, read_frame_sync, sys.stdin.buffer) - - try: - return RunnerMessageTypeAdapter.validate_python(data) - except Exception as e: - raise ValueError(f"Error validating message: {data}") from e - - -### Utils - RESPONSE FROM RUNNER - -def runner_write_response(obj: RunnerResponse) -> None: - try: - # Use model_dump_json to get proper JSON encoding - write_frame_sync(sys.stdout.buffer, obj.model_dump_json().encode('utf-8')) - except BrokenPipeError: - # Supervisor has closed the pipe, silently exit - sys.exit(0) - - -async def supervisor_read_response( - proc: asyncio.subprocess.Process, -) -> 
RunnerResponse: - assert proc.stdout is not None, ( - "proc.stdout should not be None when created with stdout=PIPE" - ) - - data: Dict[str, Any] - try: - data = await read_frame(proc.stdout) - return RunnerResponseTypeAdapter.validate_python(data) - except EOFError: - raise EOFError('No more data to read when reading response from runner.') from None - except Exception as err: - raise ValueError(f"Error validating response: {err}") from err - - ### Utils - Runner Prints @@ -142,19 +21,24 @@ def runner_print(text: str) -> None: text=text, ) - runner_write_response(obj) + conn: AsyncConnection[RunnerResponse, RunnerMessage] = get_conn() + conn.send_sync(obj) def runner_write_error(error: Exception) -> None: - # Skip writing error if it's a BrokenPipeError - supervisor is already gone - if isinstance(error, BrokenPipeError): - sys.exit(0) - error_response: ErrorResponse = ErrorResponse( type=RunnerResponseType.ErrorResponse, error_type=type(error).__name__, error_message=str(error), traceback=traceback.format_exc(), ) - runner_write_response(error_response) + + conn = get_conn() + asyncio.create_task(conn.send(error_response)) logger.opt(exception=error).exception("Critical Runner error") + + + +## TODO: To make this cleaner, it seems like we should have only one writer. +# This is fine in runner_supervisor, but in runner.py there's a risk that concurrent writes overlap. +# We can guarantee a single writer by enqueueing messages and having a dedicated writer thread. \ No newline at end of file diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index da142434..250f8fd3 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -58,7 +58,7 @@ def failed_runners( for runner_id, assigned_runner in assigned_runners.items(): if ( assigned_runner.runner is not None - and not assigned_runner.runner.healthy + and not assigned_runner.runner.runner_process.is_alive() and not isinstance(assigned_runner.status, FailedRunnerStatus) ): return RunnerFailedOp(runner_id=runner_id) diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py new file mode 100644 index 00000000..24d96bf3 --- /dev/null +++ b/src/exo/worker/runner/bootstrap.py @@ -0,0 +1,28 @@ +import asyncio +import faulthandler +import os +import sys +from multiprocessing.connection import Connection + + +def _redirect_stderr_to_file(path: str) -> None: + # Replace fd 2 (stderr) with a file descriptor pointing to `path` + fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644) + os.dup2(fd, 2) + os.close(fd) + # Rebind sys.stderr so Python's own writes go to the new fd as well (line-buffered) + sys.stderr = os.fdopen(2, "w", buffering=1, closefd=False) + +def entrypoint(raw_conn: Connection, err_path: str) -> None: + """ + Minimal entrypoint for the spawned child process. + + It redirects fd=2 (stderr) to a file provided by the parent, *then* imports + the heavy runner module so that any C/C++ or MLX logs/crashes land in that file.
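+
+    A rough sketch of the expected parent-side call, mirroring what
+    RunnerSupervisor.create does later in this patch (argument names are the
+    ones used there):
+
+        ctx = mp.get_context("spawn")
+        parent_conn, child_conn = ctx.Pipe(duplex=True)
+        runner_process = Process(target=entrypoint, args=(child_conn, err_path), daemon=False)
+        runner_process.start()
+        child_conn.close()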
+ """ + _redirect_stderr_to_file(err_path) + faulthandler.enable(file=sys.stderr, all_threads=True) + + # Import the heavy runner only after stderr is redirected + from exo.worker.runner.runner import main + asyncio.run(main(raw_conn)) diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py new file mode 100644 index 00000000..b415fb54 --- /dev/null +++ b/src/exo/worker/runner/generate.py @@ -0,0 +1,301 @@ +import asyncio +import concurrent.futures +import time +from collections.abc import AsyncGenerator +from functools import partial +from typing import Callable, Generator, Optional, Tuple + +import mlx.core as mx +from mlx.core import array +from mlx_lm.models import cache +from mlx_lm.models.cache import KVCache + +from exo.engines.mlx import Model, TokenizerWrapper +from exo.engines.mlx.utils_mlx import ( + apply_chat_template, + broadcast_from_zero, + make_kv_cache, + mx_barrier, +) +from exo.shared.types.api import ChatCompletionMessage +from exo.shared.types.tasks import ChatCompletionTaskParams +from exo.shared.types.worker.commands_runner import ( + GenerationResponse, + RunnerMessage, + RunnerResponse, + TokenizedResponse, +) +from exo.shared.types.worker.communication import ( + AsyncConnection, + runner_print, +) + +generation_stream = mx.new_stream(mx.default_device()) + +def generate_step( + prompt: mx.array, + model: Model, + *, + max_tokens: int = 256, + sampler: Callable[[mx.array], mx.array], + max_kv_size: Optional[int] = None, + prompt_cache: Optional[list[KVCache]] = None, + prefill_step_size: int = 2048, +) -> Generator[Tuple[int, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (Model): The model to use for generation. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + max_kv_size (int, optional): Maximum size of the key-value cache. Old + entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. + prefill_step_size (int): Step size for processing the prompt. + + Yields: + Tuple[int, mx.array]: One token and a vector of log probabilities. + """ + tokens = None + + # Create the KV cache for generation + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) + + def _step(input_tokens: mx.array): + nonlocal tokens + + with mx.stream(generation_stream): + logits = model( + input_tokens[None], + cache=prompt_cache, + ) + + logits = logits[:, -1, :] + + logprobs = logits - mx.logsumexp(logits, keepdims=True) # pyright: ignore[reportUnknownMemberType] + sampled = sampler(logprobs) + return sampled, logprobs.squeeze(0) + + with mx.stream(generation_stream): + total_prompt_tokens = len(prompt) + prompt_processed_tokens = 0 + + while total_prompt_tokens - prompt_processed_tokens > prefill_step_size: + runner_print(f'Prefilling {min(prefill_step_size, len(prompt))} tokens. Remaining tokens: {len(prompt)}. 
Peak memory: {mx.get_peak_memory() // 2**30} GB') + logits = model( + prompt[:prefill_step_size][None], + cache=prompt_cache + ) + + start_time = time.time() + mx.eval([c.state for c in prompt_cache] + [logits]) # type: ignore + eval_time = time.time() - start_time + prompt_processed_tokens += prefill_step_size + + prompt = prompt[prefill_step_size:] + + mx.clear_cache() + if eval_time > 7.0: + prefill_step_size = prefill_step_size // 2 + prefill_step_size = broadcast_from_zero(prefill_step_size) + prefill_step_size = max(1, prefill_step_size) + + + runner_print('finished prefil.') + y, logprobs = _step(input_tokens=prompt) + + mx.async_eval(y, logprobs) # type: ignore + n = 0 + next_y: array | None = None + next_logprobs: array | None = None + + mx.async_eval(y, logprobs) # type: ignore + n = 0 + while True: + if n != max_tokens: + assert y is not None + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) # type: ignore + if n == 0: + mx.eval(y) # type: ignore + if n == max_tokens: + break + yield int(y.item()), logprobs # type: ignore + if n % 256 == 0: + mx.clear_cache() + y, logprobs = next_y, next_logprobs + n += 1 + + + +def stream_generate( + model: Model, + tokenizer: TokenizerWrapper, + prompt: str, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + conn: AsyncConnection[RunnerResponse, RunnerMessage] | None, + prompt_cache: Optional[list[KVCache]] = None, + prefill_step_size: int = 2048, +) -> Generator[GenerationResponse, None, None]: + + # Try to infer if special tokens are needed + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + prompt_array: mx.array = mx.array(tokenizer.encode(prompt, add_special_tokens=add_special_tokens)) + if conn is not None: + conn.send_sync(TokenizedResponse(prompt_tokens=len(prompt_array))) + + detokenizer = tokenizer.detokenizer + + token_generator: Generator[Tuple[int, array], None, None] = generate_step( + prompt_array, + model, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=prompt_cache, + prefill_step_size=prefill_step_size, + ) + + token = None + detokenizer.reset() + for token, _ in token_generator: + if token in tokenizer.eos_token_ids: + break + + detokenizer.add_token(token) + + # TODO: We could put more metrics on this GenerationResponse if we wish + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + finish_reason=None, + ) + + assert token is not None + detokenizer.finalize() + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + finish_reason="stop" if token in tokenizer.eos_token_ids else "length", + ) + +async def warmup_inference( + mlx_executor: concurrent.futures.ThreadPoolExecutor, + model: Model, + tokenizer: TokenizerWrapper, + sampler: Callable[[mx.array], mx.array], +) -> int: + loop = asyncio.get_running_loop() + + warmup_prompt = await apply_chat_template( + mlx_executor=mlx_executor, + tokenizer=tokenizer, + chat_task_data=ChatCompletionTaskParams( + model="warmup", + messages=[ + ChatCompletionMessage( + role="user", + content="Prompt to warm up the inference engine. 
Repeat this.", + ) + ], + ), + ) + + tokens_generated = 0 + + def _generate_warmup(): + nonlocal tokens_generated + for _ in stream_generate( + model=model, + tokenizer=tokenizer, + prompt=warmup_prompt, + max_tokens=50, + sampler=sampler, + conn=None + ): + tokens_generated += 1 + + await loop.run_in_executor(mlx_executor, _generate_warmup) + mx_barrier() + + return tokens_generated + +async def mlx_generate( + mlx_executor: concurrent.futures.ThreadPoolExecutor, + model: Model, + tokenizer: TokenizerWrapper, + sampler: Callable[[mx.array], mx.array], + task: ChatCompletionTaskParams, + conn: AsyncConnection[RunnerResponse, RunnerMessage], +) -> AsyncGenerator[GenerationResponse]: + loop = asyncio.get_running_loop() + queue: asyncio.Queue[GenerationResponse | Exception | object] = asyncio.Queue() + sentinel = object() + + def _generate_tokens(prompt: str, max_tokens: int, cache: list[KVCache]) -> None: + try: + for generation_response in stream_generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=cache, + prefill_step_size=1024, + conn=conn, + ): + _ = loop.call_soon_threadsafe(queue.put_nowait, generation_response) + except Exception as e: + _ = loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + _ = loop.call_soon_threadsafe(queue.put_nowait, sentinel) + + # Currently we support chat-completion tasks only. + runner_print(f"task_params: {task}") + + prompt = await apply_chat_template( + mlx_executor=mlx_executor, + tokenizer=tokenizer, + chat_task_data=task, + ) + + cache_future = loop.run_in_executor( + mlx_executor, + lambda: asyncio.run(make_kv_cache( + model=model, + )) + ) + cache = await cache_future + + max_tokens = task.max_tokens or 1000 + generation_fn = partial(_generate_tokens, prompt, max_tokens, cache) + + future = loop.run_in_executor(mlx_executor, generation_fn) + + while True: + item = await queue.get() + queue.task_done() + + if item is sentinel: + break + + if isinstance(item, Exception): + raise item + + assert isinstance(item, GenerationResponse) # constrain datatype + runner_print(item.text) + yield item + + # Wait for the executor thread to complete + await future \ No newline at end of file diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index ab513c76..44874a0d 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -1,331 +1,58 @@ import asyncio import concurrent.futures import time -from collections.abc import AsyncGenerator from functools import partial -from typing import Callable, Generator, Optional, Tuple +from multiprocessing.connection import Connection -import mlx.core as mx -from mlx.core import array -from mlx_lm.generate import stream_generate as mlx_stream_generate -from mlx_lm.models import cache -from mlx_lm.models.cache import KVCache - -from exo.engines.mlx import Model, TokenizerWrapper from exo.engines.mlx.utils_mlx import ( - apply_chat_template, - broadcast_from_zero, initialize_mlx, - make_kv_cache, mlx_force_oom, mlx_setup, - mx_barrier, ) -from exo.shared.types.api import ChatCompletionMessage -from exo.shared.types.tasks import ChatCompletionTaskParams +from exo.shared.global_conn import set_conn from exo.shared.types.worker.commands_runner import ( ChatTaskMessage, ExitMessage, FinishedResponse, - GenerationResponse, InitializedResponse, + RunnerMessage, + RunnerResponse, SetupMessage, - TokenizedResponse, ) from exo.shared.types.worker.communication import ( + AsyncConnection, runner_print, - 
runner_read_message, runner_write_error, - runner_write_response, ) +from exo.shared.types.worker.shards import ShardMetadata from exo.shared.utils import ensure_type +from exo.worker.runner.generate import mlx_generate, warmup_inference from exo.worker.runner.utils import get_weights_size_kb -generation_stream = mx.new_stream(mx.default_device()) -def generate_step( - prompt: mx.array, - model: Model, - *, - max_tokens: int = 256, - sampler: Callable[[mx.array], mx.array], - max_kv_size: Optional[int] = None, - prompt_cache: Optional[list[KVCache]] = None, - prefill_step_size: int = 2048, -) -> Generator[Tuple[int, mx.array], None, None]: - """ - A generator producing token ids based on the given prompt from the model. +async def main( + raw_conn: Connection +): + conn = AsyncConnection[RunnerResponse, RunnerMessage](raw_conn) + set_conn(conn) - Args: - prompt (mx.array): The input prompt. - model (Model): The model to use for generation. - max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite - generator. Default: ``256``. - sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a - token from a vector of log probabilities. Default: ``None``. - max_kv_size (int, optional): Maximum size of the key-value cache. Old - entries (except the first 4 tokens) will be overwritten. - prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if - provided, the cache will be updated in place. - prefill_step_size (int): Step size for processing the prompt. - - Yields: - Tuple[int, mx.array]: One token and a vector of log probabilities. - """ - tokens = None - - # Create the KV cache for generation - if prompt_cache is None: - prompt_cache = cache.make_prompt_cache( - model, - max_kv_size=max_kv_size, - ) - - def _step(input_tokens: mx.array): - nonlocal tokens - - with mx.stream(generation_stream): - logits = model( - input_tokens[None], - cache=prompt_cache, - ) - - logits = logits[:, -1, :] - - logprobs = logits - mx.logsumexp(logits, keepdims=True) # pyright: ignore[reportUnknownMemberType] - sampled = sampler(logprobs) - return sampled, logprobs.squeeze(0) - - with mx.stream(generation_stream): - total_prompt_tokens = len(prompt) - prompt_processed_tokens = 0 - - while total_prompt_tokens - prompt_processed_tokens > prefill_step_size: - runner_print(f'Prefilling {min(prefill_step_size, len(prompt))} tokens. Remaining tokens: {len(prompt)}. 
Peak memory: {mx.get_peak_memory() // 2**30} GB') - logits = model( - prompt[:prefill_step_size][None], - cache=prompt_cache - ) - - start_time = time.time() - mx.eval([c.state for c in prompt_cache] + [logits]) # type: ignore - eval_time = time.time() - start_time - prompt_processed_tokens += prefill_step_size - - prompt = prompt[prefill_step_size:] - - mx.clear_cache() - if eval_time > 7.0: - prefill_step_size = prefill_step_size // 2 - prefill_step_size = broadcast_from_zero(prefill_step_size) - prefill_step_size = max(1, prefill_step_size) - - - runner_print('finished prefil.') - y, logprobs = _step(input_tokens=prompt) - - mx.async_eval(y, logprobs) # type: ignore - n = 0 - next_y: array | None = None - next_logprobs: array | None = None - - mx.async_eval(y, logprobs) # type: ignore - n = 0 - while True: - if n != max_tokens: - assert y is not None - next_y, next_logprobs = _step(y) - mx.async_eval(next_y, next_logprobs) # type: ignore - if n == 0: - mx.eval(y) # type: ignore - if n == max_tokens: - break - yield int(y.item()), logprobs # type: ignore - if n % 256 == 0: - mx.clear_cache() - y, logprobs = next_y, next_logprobs - n += 1 - - - -def stream_generate( - model: Model, - tokenizer: TokenizerWrapper, - prompt: str, - max_tokens: int, - sampler: Callable[[mx.array], mx.array], - prompt_cache: Optional[list[KVCache]] = None, - prefill_step_size: int = 2048, - warmup: bool = False, -) -> Generator[GenerationResponse, None, None]: - - # Try to infer if special tokens are needed - add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( - tokenizer.bos_token - ) - prompt_array: mx.array = mx.array(tokenizer.encode(prompt, add_special_tokens=add_special_tokens)) - if not warmup: - runner_write_response(TokenizedResponse(prompt_tokens=len(prompt_array))) - - detokenizer = tokenizer.detokenizer - - token_generator: Generator[Tuple[int, array], None, None] = generate_step( - prompt_array, - model, - max_tokens=max_tokens, - sampler=sampler, - prompt_cache=prompt_cache, - prefill_step_size=prefill_step_size, - ) - - token = None - detokenizer.reset() - for token, _ in token_generator: - if token in tokenizer.eos_token_ids: - break - - detokenizer.add_token(token) - - # TODO: We could put more metrics on this GenerationResponse if we wish - yield GenerationResponse( - text=detokenizer.last_segment, - token=token, - finish_reason=None, - ) - - assert token is not None - detokenizer.finalize() - yield GenerationResponse( - text=detokenizer.last_segment, - token=token, - finish_reason="stop" if token in tokenizer.eos_token_ids else "length", - ) - -async def warmup_inference( - mlx_executor: concurrent.futures.ThreadPoolExecutor, - model: Model, - tokenizer: TokenizerWrapper, - sampler: Callable[[mx.array], mx.array], -) -> int: - loop = asyncio.get_running_loop() - - warmup_prompt = await apply_chat_template( - mlx_executor=mlx_executor, - tokenizer=tokenizer, - chat_task_data=ChatCompletionTaskParams( - model="warmup", - messages=[ - ChatCompletionMessage( - role="user", - content="Prompt to warm up the inference engine. 
Repeat this.", - ) - ], - ), - ) - - tokens_generated = 0 - - def _generate_warmup(): - nonlocal tokens_generated - for _ in stream_generate( - model=model, - tokenizer=tokenizer, - prompt=warmup_prompt, - max_tokens=50, - sampler=sampler, - warmup=True, - ): - tokens_generated += 1 - - await loop.run_in_executor(mlx_executor, _generate_warmup) - mx_barrier() - - return tokens_generated - -async def _mlx_generate( - mlx_executor: concurrent.futures.ThreadPoolExecutor, - model: Model, - tokenizer: TokenizerWrapper, - sampler: Callable[[mx.array], mx.array], - task: ChatCompletionTaskParams, -) -> AsyncGenerator[GenerationResponse]: - loop = asyncio.get_running_loop() - queue: asyncio.Queue[GenerationResponse | Exception | object] = asyncio.Queue() - sentinel = object() - - def _generate_tokens(prompt: str, max_tokens: int, cache: list[KVCache]) -> None: - try: - for generation_response in stream_generate( - model=model, - tokenizer=tokenizer, - prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, - prompt_cache=cache, - prefill_step_size=1024, - ): - _ = loop.call_soon_threadsafe(queue.put_nowait, generation_response) - except Exception as e: - _ = loop.call_soon_threadsafe(queue.put_nowait, e) - finally: - _ = loop.call_soon_threadsafe(queue.put_nowait, sentinel) - - # Currently we support chat-completion tasks only. - runner_print(f"task_params: {task}") - - prompt = await apply_chat_template( - mlx_executor=mlx_executor, - tokenizer=tokenizer, - chat_task_data=task, - ) - - cache_future = loop.run_in_executor( - mlx_executor, - lambda: asyncio.run(make_kv_cache( - model=model, - )) - ) - cache = await cache_future - - max_tokens = task.max_tokens or 1000 - generation_fn = partial(_generate_tokens, prompt, max_tokens, cache) - - future = loop.run_in_executor(mlx_executor, generation_fn) - - while True: - item = await queue.get() - queue.task_done() - - if item is sentinel: - break - - if isinstance(item, Exception): - raise item - - assert isinstance(item, GenerationResponse) # constrain datatype - runner_print(item.text) - yield item - - # Wait for the executor thread to complete - await future - - -async def main(): try: runner_print("hello from the runner") - # Get setup info from worker - init_message = await runner_read_message() + init_message = await conn.recv() setup_message = ensure_type(init_message, SetupMessage) - model_shard_meta = setup_message.model_shard_meta + model_shard_meta: ShardMetadata = setup_message.model_shard_meta hosts = setup_message.hosts - mlx_setup(int(get_weights_size_kb(model_shard_meta) // 2**10), cache_frac_of_mrwss=0.8, wired_frac_of_mrwss=0.8) - - # For testing - these are fake break conditions - if model_shard_meta.immediate_exception: + if getattr(model_shard_meta, "immediate_exception", False): raise Exception("Fake exception - runner failed to spin up.") - if model_shard_meta.should_timeout: - await asyncio.sleep(model_shard_meta.should_timeout) + if timeout := getattr(model_shard_meta, "should_timeout", 0): + await asyncio.sleep(timeout) + + mlx_setup( + int(get_weights_size_kb(model_shard_meta) // 2**10), + cache_frac_of_mrwss=0.8, + wired_frac_of_mrwss=0.8 + ) setup_start_time = time.time() @@ -344,12 +71,12 @@ async def main(): sampler=sampler, ) runner_print(f"Warmed up by generating {toks} tokens") - runner_write_response( + await conn.send( InitializedResponse(time_taken=time.time() - setup_start_time) ) while True: - message = await runner_read_message() + message = await conn.recv() match message: case 
ChatTaskMessage(task_data=task): runner_print(f"received chat request: {str(task)[:500]}") @@ -376,16 +103,17 @@ async def main(): await asyncio.sleep(100) # Generate responses using the actual MLX generation - async for generation_response in _mlx_generate( + async for generation_response in mlx_generate( mlx_executor=mlx_executor, model=model, tokenizer=tokenizer, sampler=sampler, task=task, + conn=conn, ): - runner_write_response(generation_response) + await conn.send(generation_response) - runner_write_response(FinishedResponse()) + await conn.send(FinishedResponse()) case ExitMessage(): break case _: @@ -394,6 +122,3 @@ async def main(): except Exception as e: runner_write_error(e) - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 20a5fc09..d9cc638a 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -1,9 +1,13 @@ import asyncio import contextlib +import multiprocessing as mp +import os +import signal +import tempfile import traceback -from collections.abc import AsyncGenerator -from types import CoroutineType -from typing import Any, Callable, Optional +from multiprocessing import Process +from multiprocessing.connection import Connection +from typing import Any, AsyncGenerator, Callable, Coroutine, Optional import psutil from loguru import logger @@ -25,17 +29,15 @@ from exo.shared.types.worker.commands_runner import ( ) from exo.shared.types.worker.common import RunnerError from exo.shared.types.worker.communication import ( - supervisor_read_response, - supervisor_write_message, + AsyncConnection, ) from exo.shared.types.worker.shards import ShardMetadata +from exo.worker.runner.bootstrap import entrypoint from exo.worker.runner.utils import ( get_init_timeout, get_prefil_timeout, - get_runner_command, get_token_generate_timeout, get_weights_size_kb, - kill_process_tree, ) @@ -44,22 +46,22 @@ class RunnerSupervisor: self, model_shard_meta: ShardMetadata, hosts: list[Host], - runner_process: asyncio.subprocess.Process, + runner_process: Process, + conn: Connection, read_queue: asyncio.Queue[RunnerResponse], - write_queue: asyncio.Queue[RunnerMessage], - stderr_queue: asyncio.Queue[str], + err_path: str, ): self.model_shard_meta = model_shard_meta self.hosts = hosts self.runner_process = runner_process - self.read_queue = read_queue - self.write_queue = write_queue - self.stderr_queue = stderr_queue + self.conn = AsyncConnection[RunnerMessage, RunnerResponse](conn) + self._raw_conn = conn + self.read_queue = read_queue self.read_task = asyncio.create_task(self._read_coro()) - self.write_task = asyncio.create_task(self._write_coro()) - self.stderr_task = asyncio.create_task(self._watch_stderr()) + + self.err_path = err_path @classmethod async def create( @@ -72,29 +74,33 @@ class RunnerSupervisor: Create and initialize a RunnerSupervisor instance. The .create() classmethod pattern is used to ensure the constructor is asynchronous. 
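
        A sketch of the typical call from the worker (the full parameter list
        is in this method's signature above):

            supervisor = await RunnerSupervisor.create(
                model_shard_meta=model_shard_meta,
                hosts=hosts,
            )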
""" - cmd: list[str] = get_runner_command() - runner_process = await asyncio.create_subprocess_exec( - *cmd, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) + ctx = mp.get_context('spawn') + parent_conn, child_conn = ctx.Pipe(duplex=True) + + with tempfile.NamedTemporaryFile(prefix="child_stderr_", suffix=".log", delete=False) as tmp: + err_path = tmp.name - read_queue: asyncio.Queue[RunnerResponse] = asyncio.Queue() - write_queue: asyncio.Queue[RunnerMessage] = asyncio.Queue() - stderr_queue: asyncio.Queue[str] = asyncio.Queue() + runner_process = Process( + target=entrypoint, + args=(child_conn, err_path), + daemon=False + ) + runner_process.start() + child_conn.close() + + read_queue = asyncio.Queue[RunnerResponse]() self = cls( model_shard_meta=model_shard_meta, hosts=hosts, runner_process=runner_process, read_queue=read_queue, - write_queue=write_queue, - stderr_queue=stderr_queue, + conn=parent_conn, + err_path=err_path ) logger.info(f"Initializing mlx instance with {model_shard_meta=}") - await self.write_queue.put( + await self.conn.send( SetupMessage( model_shard_meta=model_shard_meta, hosts=hosts, @@ -104,23 +110,24 @@ class RunnerSupervisor: if not initialize_timeout: initialize_timeout = get_init_timeout(model_shard_meta) - response = await self._read_with_error_check(initialize_timeout) + response = await self._read_with_error_check(timeout=initialize_timeout) assert isinstance(response, InitializedResponse) logger.info(f"Runner initialized in {response.time_taken} seconds") return self - async def _read_with_error_check(self, timeout: float) -> RunnerResponse: + async def _read_with_error_check(self, timeout: float) -> RunnerResponse | None: """ Read from the queue with a timeout, but also check if the read_task has failed. """ - try: - assert not self.read_task.done() - except AssertionError as e_assert: + if self.read_task.done(): e = self.read_task.exception() - assert e is not None - raise e from e_assert + await self.astop() + if e is not None: + raise e + else: + return None queue_task = asyncio.create_task(self.read_queue.get()) @@ -135,78 +142,65 @@ class RunnerSupervisor: task.cancel() if queue_task in done: - response = await queue_task - if isinstance(response, ErrorResponse): - await self.astop() - raise RunnerError( - response.error_type, - response.error_message, - response.traceback or "", - ) - return response + return await queue_task if self.read_task in done: - try: - await self.read_task # Re-raises any exception from read_task - except Exception: - raise # bubble up exception + await self.astop() + await self.read_task # Re-raises any exception from read_task + + # This should never get hit. raise RunnerError("RunnerStopped", "Runner read loop terminated unexpectedly before any response.", "") # if we haven't read from the queue, we have timed out. await self.astop() # TODO: This could be handled by the called or _read_with_error_check - as we don't want a false Timeout to bring the whole runner down. 
raise asyncio.TimeoutError() + async def _read_coro(self): + while True: + try: + response: RunnerResponse = await self.conn.recv() + except EOFError as e_eof: + e = await self._raise_crashed() + if e is not None: + raise e from e_eof + break + + match response: + case PrintResponse(): + # TODO: THIS IS A REALLY IMPORTANT LOG MESSAGE, AND SHOULD BE MADE PRETTIER + logger.bind(user_facing=True).info(f"{response.text}") + case ErrorResponse(): + raise RunnerError(response.error_type, response.error_message, response.traceback) + case _: + await self.read_queue.put(response) + async def stream_response( self, task: Task, - request_started_callback: Callable[..., CoroutineType[Any, Any, None]] + request_started_callback: Callable[..., Coroutine[Any, Any, None]] | None = None, - ) -> AsyncGenerator[GenerationChunk]: + ) -> AsyncGenerator[GenerationChunk, None]: """ Streams a chat request from the model. The request is pushed to the runner, and if the shard is the terminal shard, the response is streamed back to the worker. request_started_callback is called once the request is pushed to the runner, used to publish InferencePrepareCompleted and InferenceTriggerCompleted events. """ - if not self.healthy: + if not self.runner_process.is_alive(): raise RuntimeError("Runner process was found to be dead") task_params = task.task_params assert isinstance( task_params, ChatCompletionTaskParams ) # this is messy for now. - await self.write_queue.put( + await self.conn.send( ChatTaskMessage( task_data=task_params, ), ) - while True: - try: - response = await self._read_with_error_check(5.0) - except asyncio.TimeoutError as e: - logger.bind(user_facing=True).error( - "Generation timed out during tokenization" - ) - raise e - except asyncio.LimitOverrunError as e: - raise RunnerError( - "IPCMessageTooLarge", - "The serialized prompt/response exceeded the IPC line limit. Switch to length-prefixed framing or reduce prompt size.", - "" - ) from e - - - match response: - case TokenizedResponse(): - prompt_tokens = response.prompt_tokens - break - case ErrorResponse(): - await self.astop() - raise RunnerError( - response.error_type, response.error_message, response.traceback - ) - case _: - raise ValueError(f"Unexpected response type found: {response}") + response = await self._read_with_error_check(5.0) + assert isinstance(response, TokenizedResponse) + prompt_tokens = response.prompt_tokens if request_started_callback is not None: await request_started_callback() @@ -240,42 +234,9 @@ class RunnerSupervisor: timeout = token_timeout case FinishedResponse(): break - case ErrorResponse(): - await self.astop() - raise RunnerError( - response.error_type, response.error_message, response.traceback - ) case _: raise ValueError(f"Unexpected response type found: {response}") - async def _write_coro(self): - while True: - message = await self.write_queue.get() - await supervisor_write_message(self.runner_process, message) - - async def _read_coro(self): - while True: - try: - response: RunnerResponse = await supervisor_read_response( - self.runner_process - ) - except EOFError: - e = await self._raise_crashed() - if e: - # Runner process died unexpectedly (C++ crash) - raise e from EOFError # TODO: Do we just want to create an error and put it on the read_queue here? 
- else: - continue - - match response: - case PrintResponse(): - # TODO: THIS IS A REALLY IMPORTANT LOG MESSAGE, AND SHOULD BE MADE PRETTIER - logger.bind(user_facing=True).info(f"{response.text}") - case ErrorResponse(): - ## Failure case #1: a crash happens Python, so it's neatly handled by passing an ErrorResponse with the details - await self.read_queue.put(response) - case _: - await self.read_queue.put(response) async def astop(self) -> None: # Cancel the stderr monitoring task @@ -285,12 +246,12 @@ class RunnerSupervisor: with contextlib.suppress(asyncio.CancelledError): await task - await await_task(self.stderr_task) await await_task(self.read_task) - await await_task(self.write_task) - # Kill the process and all its children - await kill_process_tree(self.runner_process) + self.runner_process.kill() + + with contextlib.suppress(Exception): + self._raw_conn.close() # Wait to make sure that the model has been unloaded from memory async def wait_for_memory_release() -> None: @@ -310,7 +271,7 @@ class RunnerSupervisor: await wait_for_memory_release() def __del__(self) -> None: - if self.runner_process.returncode is None: + if self.runner_process.is_alive(): logger.warning( "RunnerSupervisor was not stopped cleanly before garbage collection. Force killing process tree." ) @@ -331,51 +292,35 @@ class RunnerSupervisor: with contextlib.suppress(ProcessLookupError): self.runner_process.kill() - @property - def healthy(self) -> bool: - return ( - self.runner_process.returncode is None - and self.runner_process.stdin is not None - and not self.runner_process.stdin.is_closing() - and self.runner_process.stdout is not None - ) - - ## Failure case #2: a crash happens in MLX / C++ (eg segfault) that leads to error flushed to stderr and process dies async def _raise_crashed(self) -> Exception | None: - if self.runner_process.returncode == 0: + await asyncio.sleep(0.1) + + rc = self.runner_process.exitcode + if rc == 0: return None - await self.astop() + try: + with open(self.err_path, "r", errors="replace") as f: + captured = f.read() + finally: + with contextlib.suppress(OSError): + os.unlink(self.err_path) - # Accumulate all stderr messages from the queue - stderr_output = "" - while not self.stderr_queue.empty(): + # 2) Describe cause (signal vs exitcode) + cause = f"exitcode={rc}" + if isinstance(rc, int) and rc < 0: + sig = -rc try: - line = self.stderr_queue.get_nowait() - stderr_output += f"{line}\n" - except asyncio.QueueEmpty: - break + cause = f"signal={sig} ({signal.strsignal(sig)})" + except Exception: + cause = f"signal={sig}" logger.bind(user_facing=True).error( - f"Runner Error {self.runner_process.returncode}: {stderr_output}" + f"Runner terminated ({cause}).\n{captured}" ) + return RunnerError( - error_type="MLXCrash", - error_message=stderr_output, + error_type='RunnerCrash', + error_message=f"Runner terminated ({cause}).\n{captured}", traceback=traceback.format_exc(), ) - - async def _watch_stderr(self) -> None: - assert self.runner_process.stderr is not None - while True: - try: - line_bytes = await self.runner_process.stderr.readline() - if not line_bytes: - break - line = line_bytes.decode("utf-8").strip() - - await self.stderr_queue.put(line) - logger.warning(f"Runner stderr read: {line}") - except Exception as e: - logger.warning(f"Error reading runner stderr: {e}") - break diff --git a/src/exo/worker/tests/test_handlers/conftest.py b/src/exo/worker/tests/test_handlers/conftest.py index ccd1b75b..b05fb23a 100644 --- a/src/exo/worker/tests/test_handlers/conftest.py +++ 
b/src/exo/worker/tests/test_handlers/conftest.py @@ -77,6 +77,6 @@ async def worker_with_running_runner( # Is the runner actually running? supervisor = next(iter(worker.assigned_runners.values())).runner assert supervisor is not None - assert supervisor.healthy + assert supervisor.runner_process.is_alive() return worker, instance_obj diff --git a/src/exo/worker/tests/test_handlers/test_handlers_happy.py b/src/exo/worker/tests/test_handlers/test_handlers_happy.py index eaf8b078..7accd983 100644 --- a/src/exo/worker/tests/test_handlers/test_handlers_happy.py +++ b/src/exo/worker/tests/test_handlers/test_handlers_happy.py @@ -95,7 +95,7 @@ async def test_runner_up_op( # Is the runner actually running? supervisor = next(iter(worker.assigned_runners.values())).runner assert supervisor is not None - assert supervisor.healthy + assert supervisor.runner_process.is_alive() full_response = "" diff --git a/src/exo/worker/tests/test_integration/conftest.py b/src/exo/worker/tests/test_integration/conftest.py deleted file mode 100644 index b4e0ee7f..00000000 --- a/src/exo/worker/tests/test_integration/conftest.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -from logging import Logger -from typing import Awaitable, Callable - -import pytest - -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage -from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager -from exo.shared.logging import logger_test_install -from exo.shared.types.common import NodeId -from exo.worker.download.shard_downloader import NoopShardDownloader -from exo.worker.main import run -from exo.worker.worker import Worker - - -@pytest.fixture -def worker_running( - logger: Logger, -) -> Callable[[NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]]]: - async def _worker_running( - node_id: NodeId, - ) -> tuple[Worker, AsyncSQLiteEventStorage]: - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - - global_events = event_log_manager.global_events - await global_events.delete_all_events() - - shard_downloader = NoopShardDownloader() - worker = Worker( - node_id, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - asyncio.create_task(run(worker)) - - return worker, global_events - - return _worker_running diff --git a/src/exo/worker/tests/test_integration/integration_utils.py b/src/exo/worker/tests/test_integration/integration_utils.py index c0fea3ed..50154020 100644 --- a/src/exo/worker/tests/test_integration/integration_utils.py +++ b/src/exo/worker/tests/test_integration/integration_utils.py @@ -1,12 +1,55 @@ import asyncio +import contextlib +from contextlib import asynccontextmanager +from logging import Logger from typing import Callable, Optional, Tuple, TypeVar from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage +from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager +from exo.shared.logging import logger_test_install +from exo.shared.types.common import NodeId from exo.shared.types.events import ChunkGenerated, TaskStateUpdated from exo.shared.types.events.chunks import TokenChunk from exo.shared.types.tasks import TaskId, TaskStatus +from exo.worker.download.shard_downloader import NoopShardDownloader +from exo.worker.main import run +from exo.worker.worker import Worker +@asynccontextmanager +async def worker_running(node_id: NodeId, logger: Logger): + """Context manager that provides a running worker and cleans up 
after.""" + logger_test_install(logger) + event_log_manager = EventLogManager(EventLogConfig()) + await event_log_manager.initialize() + + global_events = event_log_manager.global_events + await global_events.delete_all_events() + + shard_downloader = NoopShardDownloader() + worker = Worker( + node_id, + shard_downloader=shard_downloader, + worker_events=global_events, + global_events=global_events, + ) + + # Start the worker task + task = asyncio.create_task(run(worker)) + + try: + yield worker, global_events + finally: + # Cleanup + task.cancel() + with contextlib.suppress(asyncio.CancelledError, asyncio.TimeoutError): + await asyncio.wait_for(task, timeout=1.0) + + # Clean up any runners + for assigned_runner in worker.assigned_runners.values(): + if assigned_runner.runner: + await assigned_runner.runner.astop() + async def read_streaming_response( global_events: AsyncSQLiteEventStorage, filter_task: Optional[TaskId] = None ) -> Tuple[bool, bool, str, int]: diff --git a/src/exo/worker/tests/test_integration/test_creation.py b/src/exo/worker/tests/test_integration/test_creation.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/exo/worker/tests/test_integration/test_inference.py b/src/exo/worker/tests/test_integration/test_inference.py index 23399b6d..33a3c7ee 100644 --- a/src/exo/worker/tests/test_integration/test_inference.py +++ b/src/exo/worker/tests/test_integration/test_inference.py @@ -1,10 +1,9 @@ import asyncio from logging import Logger -from typing import Awaitable, Callable +from typing import Callable import pytest -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager from exo.shared.logging import logger_test_install from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams @@ -42,6 +41,7 @@ from exo.worker.tests.constants import ( ) from exo.worker.tests.test_integration.integration_utils import ( read_streaming_response, + worker_running, ) from exo.worker.worker import Worker @@ -52,50 +52,47 @@ def user_message(): return "What's the capital of Japan?" async def test_runner_inference( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], + logger: Logger, ): - _worker, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( + [ + InstanceCreated( + instance=instance_value, + ), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, + ) - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated( - instance=instance_value, - ), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) + # TODO: This needs to get fixed - sometimes it misses the 'starting' event. + ( + seen_task_started, + seen_task_finished, + response_string, + _, + ) = await read_streaming_response(global_events) - # TODO: This needs to get fixed - sometimes it misses the 'starting' event. 
- ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) + assert seen_task_started + assert seen_task_finished + assert "tokyo" in response_string.lower() - assert seen_task_started - assert seen_task_finished - assert "tokyo" in response_string.lower() + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) + await asyncio.sleep(0.3) async def test_2_runner_inference( @@ -112,13 +109,15 @@ async def test_2_runner_inference( global_events = event_log_manager.global_events await global_events.delete_all_events() + tasks: list[asyncio.Task[None]] = [] + worker1 = Worker( NODE_A, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker1)) + tasks.append(asyncio.create_task(run(worker1))) worker2 = Worker( NODE_B, @@ -126,7 +125,7 @@ async def test_2_runner_inference( worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker2)) + tasks.append(asyncio.create_task(run(worker2))) ## Instance model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") @@ -183,6 +182,21 @@ async def test_2_runner_inference( await asyncio.sleep(2.0) + for task in tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # This is expected when we cancel a task + except Exception: + pass # Suppress any other exceptions during cleanup + + + for worker in (worker1, worker2): + for assigned_runner in worker.assigned_runners.values(): + if assigned_runner.runner: + await assigned_runner.runner.astop() + # TODO: Multi message parallel async def test_2_runner_multi_message( @@ -198,13 +212,15 @@ async def test_2_runner_multi_message( global_events = event_log_manager.global_events await global_events.delete_all_events() + tasks: list[asyncio.Task[None]] = [] + worker1 = Worker( NODE_A, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker1)) + tasks.append(asyncio.create_task(run(worker1))) worker2 = Worker( NODE_B, @@ -212,7 +228,7 @@ async def test_2_runner_multi_message( worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker2)) + tasks.append(asyncio.create_task(run(worker2))) ## Instance model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") @@ -297,4 +313,18 @@ async def test_2_runner_multi_message( origin=MASTER_NODE_ID, ) + for task in tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # This is expected when we cancel a task + except Exception: + pass # Suppress any other exceptions during cleanup + + for worker in (worker1, worker2): + for assigned_runner in worker.assigned_runners.values(): + if assigned_runner.runner: + await assigned_runner.runner.astop() + await asyncio.sleep(2.0) diff --git a/src/exo/worker/tests/test_integration/test_inference_sad.py b/src/exo/worker/tests/test_integration/test_inference_sad.py index e42c92a7..e88bba39 100644 --- a/src/exo/worker/tests/test_integration/test_inference_sad.py +++ b/src/exo/worker/tests/test_integration/test_inference_sad.py @@ -1,13 +1,13 @@ import asyncio from collections.abc import AsyncGenerator +from logging import Logger from types import CoroutineType -from 
typing import Any, Awaitable, Callable +from typing import Any, Callable import pytest from _pytest.monkeypatch import MonkeyPatch # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage from exo.shared.types.common import NodeId from exo.shared.types.events import ( ChunkGenerated, @@ -26,7 +26,6 @@ from exo.shared.types.worker.instances import ( InstanceStatus, ) from exo.shared.types.worker.runners import FailedRunnerStatus -from exo.worker.main import Worker from exo.worker.runner.runner_supervisor import RunnerSupervisor from exo.worker.tests.constants import ( INSTANCE_1_ID, @@ -35,7 +34,10 @@ from exo.worker.tests.constants import ( RUNNER_1_ID, TASK_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import until_event_with_timeout +from exo.worker.tests.test_integration.integration_utils import ( + until_event_with_timeout, + worker_running, +) @pytest.fixture @@ -46,83 +48,78 @@ def user_message(): async def test_stream_response_failed_always( monkeypatch: MonkeyPatch, - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + logger: Logger, chat_completion_task: Callable[[InstanceId, TaskId], Task], ) -> None: - _, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE + async def mock_stream_response( + self: RunnerSupervisor, + task: Task, + request_started_callback: Callable[..., CoroutineType[Any, Any, None]] + | None = None, + ) -> AsyncGenerator[GenerationChunk]: + raise RuntimeError("Simulated stream response failure") + return + yield - async def mock_stream_response( - self: RunnerSupervisor, - task: Task, - request_started_callback: Callable[..., CoroutineType[Any, Any, None]] - | None = None, - ) -> AsyncGenerator[GenerationChunk]: - raise RuntimeError("Simulated stream response failure") - return - yield + monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0) - - events = await global_events.get_events_since(0) - - assert ( - len( + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, ) - == 3 - ) - assert ( - len( + + await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0) + + events = await global_events.get_events_since(0) + + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 3 + ) + assert ( + len( + [ + x + for x 
in events + if isinstance(x.event, TaskStateUpdated) + and x.event.task_status == TaskStatus.FAILED + ] + ) + == 3 + ) + assert any([isinstance(x.event, InstanceDeleted) for x in events]) + + await global_events.append_events( [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED - ] + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID, ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) + await asyncio.sleep(0.3) async def test_stream_response_failed_once( monkeypatch: MonkeyPatch, - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], + logger: Logger, instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], ): @@ -148,160 +145,156 @@ async def test_stream_response_failed_once( monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - worker, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (worker, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout( - global_events, - ChunkGenerated, - 1, - condition=lambda x: isinstance(x.chunk, TokenChunk) - and x.chunk.finish_reason is not None, - timeout=30.0, - ) - - # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say - # 'asser that there was a time that the error_type, error_message was not none and the failure count was nonzero' - - # as we reset the failures back to zero when we have a successful inference. - assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 - assert worker.state.tasks[TASK_1_ID].error_type is None - assert worker.state.tasks[TASK_1_ID].error_message is None - - events = await global_events.get_events_since(0) - assert ( - len( + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + await global_events.append_events( [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, ) - == 1 - ) - assert ( - len( + + await until_event_with_timeout( + global_events, + ChunkGenerated, + 1, + condition=lambda x: isinstance(x.chunk, TokenChunk) + and x.chunk.finish_reason is not None, + timeout=30.0, + ) + + # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say + # 'asser that there was a time that the error_type, error_message was not none and the failure count was nonzero' + + # as we reset the failures back to zero when we have a successful inference. 
+ assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 + assert worker.state.tasks[TASK_1_ID].error_type is None + assert worker.state.tasks[TASK_1_ID].error_message is None + + events = await global_events.get_events_since(0) + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 1 + ) + assert ( + len( + [ + x + for x in events + if isinstance(x.event, TaskStateUpdated) + and x.event.task_status == TaskStatus.FAILED + ] + ) + == 1 + ) + + response_string = "" + events = await global_events.get_events_since(0) + + seen_task_started, seen_task_finished = False, False + for wrapped_event in events: + event = wrapped_event.event + if isinstance(event, TaskStateUpdated): + if event.task_status == TaskStatus.RUNNING: + seen_task_started = True + if event.task_status == TaskStatus.COMPLETE: + seen_task_finished = True + + if isinstance(event, ChunkGenerated): + assert isinstance(event.chunk, TokenChunk) + response_string += event.chunk.text + + assert "queen" in response_string.lower() + assert seen_task_started + assert seen_task_finished + + await global_events.append_events( [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED - ] + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID, ) - == 1 - ) - response_string = "" - events = await global_events.get_events_since(0) - - seen_task_started, seen_task_finished = False, False - for wrapped_event in events: - event = wrapped_event.event - if isinstance(event, TaskStateUpdated): - if event.task_status == TaskStatus.RUNNING: - seen_task_started = True - if event.task_status == TaskStatus.COMPLETE: - seen_task_finished = True - - if isinstance(event, ChunkGenerated): - assert isinstance(event.chunk, TokenChunk) - response_string += event.chunk.text - - assert "queen" in response_string.lower() - assert seen_task_started - assert seen_task_finished - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) + await asyncio.sleep(0.3) async def test_stream_response_timeout( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], chat_completion_task: Callable[[InstanceId, TaskId], Task], + logger: Logger, ): - _, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT" - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout(global_events, TaskFailed, multiplicity=3, timeout=30.0) - - events = await global_events.get_events_since(0) - print(events) - assert ( - len( + task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) + task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT" + await global_events.append_events( [ - x - for x in 
events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + InstanceCreated(instance=instance_value), + TaskCreated(task_id=task.task_id, task=task), + ], + origin=MASTER_NODE_ID, ) - == 3 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.FAILED - ] - ) - == 3 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskFailed) - and "timeouterror" in x.event.error_type.lower() - ] - ) - == 3 - ) - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) + await until_event_with_timeout(global_events, TaskFailed, multiplicity=3, timeout=30.0) - await asyncio.sleep(0.3) + events = await global_events.get_events_since(0) + print(events) + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 3 + ) + assert ( + len( + [ + x + for x in events + if isinstance(x.event, TaskStateUpdated) + and x.event.task_status == TaskStatus.FAILED + ] + ) + == 3 + ) + assert ( + len( + [ + x + for x in events + if isinstance(x.event, TaskFailed) + and "timeouterror" in x.event.error_type.lower() + ] + ) + == 3 + ) + + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance_value.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) + + await asyncio.sleep(0.3) diff --git a/src/exo/worker/tests/test_integration/test_instantiation.py b/src/exo/worker/tests/test_integration/test_instantiation.py index 8671777e..673afd92 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation.py +++ b/src/exo/worker/tests/test_integration/test_instantiation.py @@ -1,7 +1,7 @@ -from typing import Awaitable, Callable +from logging import Logger +from typing import Callable # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage from exo.shared.types.common import NodeId # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py @@ -18,51 +18,50 @@ from exo.shared.types.worker.instances import ( from exo.shared.types.worker.runners import ( FailedRunnerStatus, ) -from exo.worker.main import Worker from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, NODE_A, RUNNER_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import until_event_with_timeout +from exo.worker.tests.test_integration.integration_utils import ( + until_event_with_timeout, + worker_running, +) async def test_runner_spinup_timeout( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + logger: Logger, ): - _, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID 
- ) - - await until_event_with_timeout( - global_events, - RunnerStatusUpdated, - multiplicity=3, - condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), - ) - - # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + await global_events.append_events( + [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + + await until_event_with_timeout( + global_events, + RunnerStatusUpdated, + multiplicity=3, + condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), + ) + + # Ensure the correct events have been emitted + events = await global_events.get_events_since(0) + + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 3 + ) + assert any([isinstance(x.event, InstanceDeleted) for x in events]) \ No newline at end of file diff --git a/src/exo/worker/tests/test_integration/test_instantiation_sad.py b/src/exo/worker/tests/test_integration/test_instantiation_sad.py index c4329162..ed4b59e4 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation_sad.py +++ b/src/exo/worker/tests/test_integration/test_instantiation_sad.py @@ -1,8 +1,8 @@ import asyncio -from typing import Awaitable, Callable +from logging import Logger +from typing import Callable # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage from exo.shared.types.common import NodeId # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py @@ -19,88 +19,84 @@ from exo.shared.types.worker.instances import ( from exo.shared.types.worker.runners import ( FailedRunnerStatus, ) -from exo.worker.main import Worker from exo.worker.tests.constants import ( INSTANCE_1_ID, MASTER_NODE_ID, NODE_A, RUNNER_1_ID, ) -from exo.worker.tests.test_integration.integration_utils import until_event_with_timeout +from exo.worker.tests.test_integration.integration_utils import ( + until_event_with_timeout, + worker_running, +) async def test_runner_spinup_exception( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + logger: Logger, ): - _, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + instance_value.shard_assignments.runner_to_shard[ + RUNNER_1_ID + ].immediate_exception = True - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[ - RUNNER_1_ID - ].immediate_exception = True - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await asyncio.sleep(10.0) - - # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + await 
global_events.append_events( + [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + + await asyncio.sleep(10.0) + + # Ensure the correct events have been emitted + events = await global_events.get_events_since(0) + + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 3 + ) + assert any([isinstance(x.event, InstanceDeleted) for x in events]) async def test_runner_spinup_timeout( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], instance: Callable[[InstanceId, NodeId, RunnerId], Instance], + logger: Logger, ): - _, global_events = await worker_running(NODE_A) + async with worker_running(NODE_A, logger) as (_, global_events): + instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) + instance_value.instance_type = InstanceStatus.ACTIVE + instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[RUNNER_1_ID].should_timeout = 10 - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await until_event_with_timeout( - global_events, - RunnerStatusUpdated, - multiplicity=3, - condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), - ) - - # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] + await global_events.append_events( + [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) + + await until_event_with_timeout( + global_events, + RunnerStatusUpdated, + multiplicity=3, + condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), + ) + + # Ensure the correct events have been emitted + events = await global_events.get_events_since(0) + + assert ( + len( + [ + x + for x in events + if isinstance(x.event, RunnerStatusUpdated) + and isinstance(x.event.runner_status, FailedRunnerStatus) + ] + ) + == 3 + ) + assert any([isinstance(x.event, InstanceDeleted) for x in events]) diff --git a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py b/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py index f36818c9..2cc9f7da 100644 --- a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py +++ b/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py @@ -51,11 +51,12 @@ from exo.worker.tests.constants import ( from exo.worker.tests.test_integration.integration_utils import ( read_streaming_response, until_event_with_timeout, + worker_running, ) from exo.worker.worker import Worker MODEL_ID = "mlx-community/Llama-3.3-70B-Instruct-4bit" - +SKIP = True @pytest.fixture async def model_meta() -> ModelMetadata: @@ -72,9 +73,7 @@ def _get_model_size_gb(path: str) -> float: total_size += os.path.getsize(filepath) return total_size / (1024**3) # Convert bytes to GB - -@pytest.mark.skipif( - True or not ( +skip = SKIP or not ( os.path.exists( os.path.expanduser( "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" @@ -86,7 +85,10 @@ def 
_get_model_size_gb(path: str) -> float: ) ) > 30 - ), +) + +@pytest.mark.skipif( + skip, reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", ) async def test_ttft( @@ -94,235 +96,208 @@ async def test_ttft( pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], hosts: Callable[[int], list[Host]], ): - logger_test_install(logger) - event_log_manager = EventLogManager(EventLogConfig()) - await event_log_manager.initialize() - shard_downloader = NoopShardDownloader() + async with worker_running(NODE_A, logger) as (_, global_events): + ## Instance + model_id = ModelId(MODEL_ID) - global_events = event_log_manager.global_events - await global_events.delete_all_events() + shard_assignments = ShardAssignments( + model_id=model_id, + runner_to_shard={RUNNER_1_ID: pipeline_shard_meta(1, 0)}, + node_to_runner={NODE_A: RUNNER_1_ID}, + ) - worker1 = Worker( - NODE_A, - shard_downloader=shard_downloader, - worker_events=global_events, - global_events=global_events, - ) - asyncio.create_task(run(worker1)) + instance = Instance( + instance_id=INSTANCE_1_ID, + instance_type=InstanceStatus.ACTIVE, + shard_assignments=shard_assignments, + hosts=hosts(1), + ) - ## Instance - model_id = ModelId(MODEL_ID) + # Create instance first + await global_events.append_events( + [InstanceCreated(instance=instance)], origin=MASTER_NODE_ID + ) - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={RUNNER_1_ID: pipeline_shard_meta(1, 0)}, - node_to_runner={NODE_A: RUNNER_1_ID}, - ) + await until_event_with_timeout( + global_events, + event_type=RunnerStatusUpdated, + condition=lambda x: isinstance(x.runner_status, LoadedRunnerStatus), + ) + logger.info("model loaded.") - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.ACTIVE, - shard_assignments=shard_assignments, - hosts=hosts(1), - ) + # First inference + task1_params = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage( + role="user", content="Please write a haiku about a flower." + ) + ], + stream=True, + max_tokens=100, + ) + task1 = ChatCompletionTask( + task_id=TASK_1_ID, + command_id=COMMAND_1_ID, + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=task1_params, + ) - # Create instance first - await global_events.append_events( - [InstanceCreated(instance=instance)], origin=MASTER_NODE_ID - ) + print("Starting first inference...") + # Record the current event index before creating the task + idx_before_task1 = await global_events.get_last_idx() - await until_event_with_timeout( - global_events, - event_type=RunnerStatusUpdated, - condition=lambda x: isinstance(x.runner_status, LoadedRunnerStatus), - ) - logger.info("model loaded.") + task_created_time_1 = time.time() + await global_events.append_events( + [TaskCreated(task_id=task1.task_id, task=task1)], origin=MASTER_NODE_ID + ) - # First inference - task1_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="Please write a haiku about a flower." 
- ) - ], - stream=True, - max_tokens=100, - ) - task1 = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=task1_params, - ) + # Wait for first chunk to measure time to first token + first_chunk_seen_1 = False + time_to_first_token_1: None | float = None + while not first_chunk_seen_1: + events = await global_events.get_events_since(idx_before_task1) + for wrapped_event in events: + if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( + wrapped_event.event, "chunk" + ): + first_chunk_time_1 = time.time() + time_to_first_token_1 = first_chunk_time_1 - task_created_time_1 + first_chunk_seen_1 = True + break + if not first_chunk_seen_1: + await asyncio.sleep(0.01) - print("Starting first inference...") - # Record the current event index before creating the task - idx_before_task1 = await global_events.get_last_idx() + _, seen_task_finished_1, response_string_1, token_count_1 = await read_streaming_response( + global_events + ) + total_time_1 = time.time() - task_created_time_1 - task_created_time_1 = time.time() - await global_events.append_events( - [TaskCreated(task_id=task1.task_id, task=task1)], origin=MASTER_NODE_ID - ) + assert seen_task_finished_1 - # Wait for first chunk to measure time to first token - first_chunk_seen_1 = False - time_to_first_token_1: None | float = None - while not first_chunk_seen_1: - events = await global_events.get_events_since(idx_before_task1) - for wrapped_event in events: - if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( - wrapped_event.event, "chunk" - ): - first_chunk_time_1 = time.time() - time_to_first_token_1 = first_chunk_time_1 - task_created_time_1 - first_chunk_seen_1 = True - break - if not first_chunk_seen_1: - await asyncio.sleep(0.01) + # Wait for first task to complete + await asyncio.sleep(5.0) - _, seen_task_finished_1, response_string_1, token_count_1 = await read_streaming_response( - global_events - ) - total_time_1 = time.time() - task_created_time_1 + # Second inference + task2_params = ChatCompletionTaskParams( + model="gpt-4", + messages=[ + ChatCompletionMessage( + role="user", content="Write me a haiku about a robot." + ) + ], + stream=True, + max_tokens=150, + ) + task2 = ChatCompletionTask( + task_id=TASK_2_ID, + command_id=COMMAND_2_ID, + instance_id=INSTANCE_1_ID, + task_type=TaskType.CHAT_COMPLETION, + task_status=TaskStatus.PENDING, + task_params=task2_params, + ) - assert seen_task_finished_1 + print("Starting second inference...") + # Record the current event index before creating the second task + idx_before_task2 = await global_events.get_last_idx() - # Wait for first task to complete - await asyncio.sleep(5.0) + task_created_time_2 = time.time() + await global_events.append_events( + [TaskCreated(task_id=task2.task_id, task=task2)], origin=MASTER_NODE_ID + ) - # Second inference - task2_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="Write me a haiku about a robot." 
- ) - ], - stream=True, - max_tokens=150, - ) - task2 = ChatCompletionTask( - task_id=TASK_2_ID, - command_id=COMMAND_2_ID, - instance_id=INSTANCE_1_ID, - task_type=TaskType.CHAT_COMPLETION, - task_status=TaskStatus.PENDING, - task_params=task2_params, - ) + # Wait for first chunk of second task to measure time to first token + first_chunk_seen_2 = False + time_to_first_token_2: float | None = None + while not first_chunk_seen_2: + events = await global_events.get_events_since(idx_before_task2) + for wrapped_event in events: + if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( + wrapped_event.event, "chunk" + ): + first_chunk_time_2 = time.time() + time_to_first_token_2 = first_chunk_time_2 - task_created_time_2 + first_chunk_seen_2 = True + break + if not first_chunk_seen_2: + await asyncio.sleep(0.01) - print("Starting second inference...") - # Record the current event index before creating the second task - idx_before_task2 = await global_events.get_last_idx() + _, seen_task_finished_2, response_string_2, token_count_2 = await read_streaming_response( + global_events, filter_task=TASK_2_ID + ) + total_time_2 = time.time() - task_created_time_2 - task_created_time_2 = time.time() - await global_events.append_events( - [TaskCreated(task_id=task2.task_id, task=task2)], origin=MASTER_NODE_ID - ) + assert seen_task_finished_2 + assert time_to_first_token_1 + assert time_to_first_token_2 - # Wait for first chunk of second task to measure time to first token - first_chunk_seen_2 = False - time_to_first_token_2: float | None = None - while not first_chunk_seen_2: - events = await global_events.get_events_since(idx_before_task2) - for wrapped_event in events: - if isinstance(wrapped_event.event, ChunkGenerated) and hasattr( - wrapped_event.event, "chunk" - ): - first_chunk_time_2 = time.time() - time_to_first_token_2 = first_chunk_time_2 - task_created_time_2 - first_chunk_seen_2 = True - break - if not first_chunk_seen_2: - await asyncio.sleep(0.01) + # Calculate TPS metrics + # Prompt is approximately 45 tokens according to user + prompt_tokens = 45 - _, seen_task_finished_2, response_string_2, token_count_2 = await read_streaming_response( - global_events, filter_task=TASK_2_ID - ) - total_time_2 = time.time() - task_created_time_2 + # Prefill TPS = prompt tokens / time to first token + prefill_tps_1 = prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0 + prefill_tps_2 = prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0 - assert seen_task_finished_2 - assert time_to_first_token_1 - assert time_to_first_token_2 + # Generation TPS = generated tokens / generation time + # Generation time = total time - time to first token + generation_time_1 = total_time_1 - time_to_first_token_1 + generation_time_2 = total_time_2 - time_to_first_token_2 + generation_tps_1 = token_count_1 / generation_time_1 if generation_time_1 > 0 else 0 + generation_tps_2 = token_count_2 / generation_time_2 if generation_time_2 > 0 else 0 - # Calculate TPS metrics - # Prompt is approximately 45 tokens according to user - prompt_tokens = 45 + # Display time to first token profiling results + print("\n=== Time to First Token Profiling ===") + print(f"First inference ('{task1.task_params.messages[0].content}'):") + print(f" Time to first token: {time_to_first_token_1:.3f}s") + print(f" Total completion time: {total_time_1:.3f}s") + print(f" Tokens generated: {token_count_1}") + print(f" Response length: {len(response_string_1)} chars") + print(f" Prefill TPS: 
{prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)") + print(f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)") - # Prefill TPS = prompt tokens / time to first token - prefill_tps_1 = prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0 - prefill_tps_2 = prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0 + print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):") + print(f" Time to first token: {time_to_first_token_2:.3f}s") + print(f" Total completion time: {total_time_2:.3f}s") + print(f" Tokens generated: {token_count_2}") + print(f" Response length: {len(response_string_2)} chars") + print(f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)") + print(f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)") - # Generation TPS = generated tokens / generation time - # Generation time = total time - time to first token - generation_time_1 = total_time_1 - time_to_first_token_1 - generation_time_2 = total_time_2 - time_to_first_token_2 - generation_tps_1 = token_count_1 / generation_time_1 if generation_time_1 > 0 else 0 - generation_tps_2 = token_count_2 / generation_time_2 if generation_time_2 > 0 else 0 + print("\nComparison:") + print(f" Second inference time to first token: {time_to_first_token_2/time_to_first_token_1:.2f}x the first") + print(f" Second inference prefill TPS: {prefill_tps_2/prefill_tps_1:.2f}x the first") + print(f" Second inference generation TPS: {generation_tps_2/generation_tps_1:.2f}x the first") - # Display time to first token profiling results - print("\n=== Time to First Token Profiling ===") - print(f"First inference ('{task1.task_params.messages[0].content}'):") - print(f" Time to first token: {time_to_first_token_1:.3f}s") - print(f" Total completion time: {total_time_1:.3f}s") - print(f" Tokens generated: {token_count_1}") - print(f" Response length: {len(response_string_1)} chars") - print(f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)") - print(f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)") + # Basic assertions to ensure responses make sense + assert len(response_string_1) > 0 + assert len(response_string_2) > 0 + assert time_to_first_token_1 and time_to_first_token_1 > 0 + assert time_to_first_token_2 and time_to_first_token_2 > 0 - print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):") - print(f" Time to first token: {time_to_first_token_2:.3f}s") - print(f" Total completion time: {total_time_2:.3f}s") - print(f" Tokens generated: {token_count_2}") - print(f" Response length: {len(response_string_2)} chars") - print(f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)") - print(f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)") + # Cleanup + idx = await global_events.get_last_idx() + await asyncio.sleep(1.0) + events = await global_events.get_events_since(idx) + assert len(events) == 0 - print("\nComparison:") - print(f" Second inference time to first token: {time_to_first_token_2/time_to_first_token_1:.2f}x the first") - print(f" Second inference prefill TPS: {prefill_tps_2/prefill_tps_1:.2f}x the first") - print(f" Second inference 
generation TPS: {generation_tps_2/generation_tps_1:.2f}x the first") + await global_events.append_events( + [ + InstanceDeleted( + instance_id=instance.instance_id, + ), + ], + origin=MASTER_NODE_ID, + ) - # Basic assertions to ensure responses make sense - assert len(response_string_1) > 0 - assert len(response_string_2) > 0 - assert time_to_first_token_1 and time_to_first_token_1 > 0 - assert time_to_first_token_2 and time_to_first_token_2 > 0 - - # Cleanup - idx = await global_events.get_last_idx() - await asyncio.sleep(1.0) - events = await global_events.get_events_since(idx) - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) + await asyncio.sleep(2.0) @pytest.mark.skipif( - True or not ( - os.path.exists( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - and _get_model_size_gb( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - > 30 - ), + skip, reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", ) async def test_2_runner_inference( @@ -339,13 +314,15 @@ async def test_2_runner_inference( global_events = event_log_manager.global_events await global_events.delete_all_events() + tasks: list[asyncio.Task[None]] = [] + worker1 = Worker( NODE_A, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker1)) + tasks.append(asyncio.create_task(run(worker1))) worker2 = Worker( NODE_B, @@ -353,7 +330,7 @@ async def test_2_runner_inference( worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker2)) + tasks.append(asyncio.create_task(run(worker2))) ## Instance model_id = ModelId(MODEL_ID) @@ -417,21 +394,23 @@ async def test_2_runner_inference( await asyncio.sleep(2.0) + for task in tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # This is expected when we cancel a task + except Exception: + pass # Suppress any other exceptions during cleanup + + for worker in (worker1, worker2): + for assigned_runner in worker.assigned_runners.values(): + if assigned_runner.runner: + await assigned_runner.runner.astop() + @pytest.mark.skipif( - True or not ( - os.path.exists( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - and _get_model_size_gb( - os.path.expanduser( - "~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/" - ) - ) - > 30 - ), + skip, reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", ) async def test_parallel_inference( @@ -448,13 +427,15 @@ async def test_parallel_inference( global_events = event_log_manager.global_events await global_events.delete_all_events() + tasks: list[asyncio.Task[None]] = [] + worker1 = Worker( NODE_A, shard_downloader=shard_downloader, worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker1)) + tasks.append(asyncio.create_task(run(worker1))) worker2 = Worker( NODE_B, @@ -462,7 +443,7 @@ async def test_parallel_inference( worker_events=global_events, global_events=global_events, ) - asyncio.create_task(run(worker2)) + tasks.append(asyncio.create_task(run(worker2))) ## Instance model_id = ModelId(MODEL_ID) @@ -579,3 +560,17 @@ async def test_parallel_inference( ) await asyncio.sleep(2.0) + + for task in tasks: + task.cancel() + 
try: + await task + except asyncio.CancelledError: + pass # This is expected when we cancel a task + except Exception: + pass # Suppress any other exceptions during cleanup + + for worker in (worker1, worker2): + for assigned_runner in worker.assigned_runners.values(): + if assigned_runner.runner: + await assigned_runner.runner.astop() diff --git a/src/exo/worker/tests/test_runner_connection.py b/src/exo/worker/tests/test_runner_connection.py index a561de85..29e2f1ba 100644 --- a/src/exo/worker/tests/test_runner_connection.py +++ b/src/exo/worker/tests/test_runner_connection.py @@ -119,7 +119,7 @@ async def check_runner_connection( await asyncio.sleep(0.001) runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0) - ret = runner_supervisor is not None and runner_supervisor.healthy + ret = runner_supervisor is not None and runner_supervisor.runner_process.is_alive() await global_events.append_events( [ diff --git a/src/exo/worker/tests/test_supervisor/test_memory.py b/src/exo/worker/tests/test_supervisor/test_memory.py index c7c494ba..e250e5a4 100644 --- a/src/exo/worker/tests/test_supervisor/test_memory.py +++ b/src/exo/worker/tests/test_supervisor/test_memory.py @@ -1,5 +1,5 @@ -from asyncio.subprocess import Process from logging import Logger +from multiprocessing import Process from typing import Callable import psutil diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor.py b/src/exo/worker/tests/test_supervisor/test_supervisor.py index 17756c18..1a7f7fb3 100644 --- a/src/exo/worker/tests/test_supervisor/test_supervisor.py +++ b/src/exo/worker/tests/test_supervisor/test_supervisor.py @@ -205,8 +205,7 @@ async def test_supervisor_handles_terminated_runner( supervisor.runner_process.terminate() await asyncio.sleep(0.1) - assert not supervisor.healthy - assert supervisor.runner_process.returncode is not None + assert not supervisor.runner_process.is_alive() del supervisor @@ -226,13 +225,12 @@ async def test_supervisor_handles_killed_runner( hosts=hosts(1, offset=10), ) - assert supervisor.healthy + assert supervisor.runner_process.is_alive() # Forcibly kill the runner supervisor.runner_process.kill() await asyncio.sleep(0.1) - assert not supervisor.healthy - assert supervisor.runner_process.returncode is not None + assert not supervisor.runner_process.is_alive() del supervisor diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py b/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py index 959e41b2..87a06273 100644 --- a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py +++ b/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py @@ -24,6 +24,11 @@ async def test_supervisor_instantiation_exception( model_shard_meta = pipeline_shard_meta(1, 0) model_shard_meta.immediate_exception = True + # _ = await RunnerSupervisor.create( + # model_shard_meta=model_shard_meta, + # hosts=hosts(1, offset=10), + # ) + with pytest.raises(RunnerError): _ = await RunnerSupervisor.create( model_shard_meta=model_shard_meta, diff --git a/src/exo/worker/worker.py b/src/exo/worker/worker.py index 7b7fa689..606f487a 100644 --- a/src/exo/worker/worker.py +++ b/src/exo/worker/worker.py @@ -240,25 +240,12 @@ class Worker: initialize_timeout=initialize_timeout, ) - if assigned_runner.runner.healthy: + if assigned_runner.runner.runner_process.is_alive(): assigned_runner.status = LoadedRunnerStatus() else: - # Log detailed reasons why the runner is not healthy runner = assigned_runner.runner - health_issues: list[str] = [] - - if 
runner.runner_process.returncode is not None: - health_issues.append( - f"runner_process.returncode is {runner.runner_process.returncode}" - ) - if runner.runner_process.stdin is None: - health_issues.append("runner_process.stdin is None") - elif runner.runner_process.stdin.is_closing(): - health_issues.append("runner_process.stdin is closing") - if runner.runner_process.stdout is None: - health_issues.append("runner_process.stdout is None") - - logger.warning(f"Runner status is not healthy: {', '.join(health_issues)}") + logger.warning(f"Runner process is not alive (exit code {runner.runner_process.exitcode})") + assigned_runner.status = FailedRunnerStatus() yield self.assigned_runners[op.runner_id].status_update_event() @@ -318,7 +305,7 @@ class Worker: ) assert assigned_runner.runner is not None - assert assigned_runner.runner.healthy + assert assigned_runner.runner.runner_process.is_alive() async for chunk in assigned_runner.runner.stream_response( task=op.task, request_started_callback=partial(running_callback, queue) @@ -407,7 +394,9 @@ class Worker: if runner_id in self.assigned_runners: assigned_runner = self.assigned_runners[runner_id] - assigned_runner.runner = None + if assigned_runner.runner is not None: + await assigned_runner.runner.astop() + assigned_runner.runner = None assigned_runner.status = FailedRunnerStatus(error_message=str(e)) assigned_runner.failures.append((time.time(), e))
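
Note on the test refactor above: every integration test now enters worker_running(NODE_A, logger) as an async context manager instead of awaiting a fixture, but the helper itself lives in integration_utils and is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it reuses the EventLogManager/NoopShardDownloader/run machinery that test_2_runner_inference still constructs by hand (their import paths are not shown in this diff, so they are omitted here) and mirrors the task-cancel/astop() cleanup this diff adds to the 70B tests. It is a sketch, not the actual implementation.

# Hypothetical sketch of integration_utils.worker_running; the real helper may differ.
# EventLogManager, EventLogConfig, NoopShardDownloader and run are assumed to be the
# same objects used in test_inference_llama70B.py; their import paths are not shown
# in this diff and are therefore left out.
import asyncio
from contextlib import asynccontextmanager
from logging import Logger
from typing import AsyncIterator

from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage
from exo.shared.types.common import NodeId
from exo.worker.worker import Worker


@asynccontextmanager
async def worker_running(
    node_id: NodeId, logger: Logger
) -> AsyncIterator[tuple[Worker, AsyncSQLiteEventStorage]]:
    """Start a Worker against a fresh event log and tear everything down on exit."""
    event_log_manager = EventLogManager(EventLogConfig())
    await event_log_manager.initialize()
    global_events = event_log_manager.global_events
    await global_events.delete_all_events()

    worker = Worker(
        node_id,
        shard_downloader=NoopShardDownloader(),
        worker_events=global_events,
        global_events=global_events,
    )
    run_task = asyncio.create_task(run(worker))
    try:
        yield worker, global_events
    finally:
        # Same cleanup this diff adds to test_2_runner_inference and
        # test_parallel_inference: cancel the run task, then stop any
        # runners the worker still owns.
        run_task.cancel()
        try:
            await run_task
        except asyncio.CancelledError:
            pass
        for assigned_runner in worker.assigned_runners.values():
            if assigned_runner.runner:
                await assigned_runner.runner.astop()

Centralising the teardown like this is what lets the tests drop the old worker_running fixture and its Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] typing.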
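
The assertions also lean on until_event_with_timeout(global_events, event_type, multiplicity, condition=..., timeout=...), which is likewise imported from integration_utils rather than defined here. The following is a polling sketch consistent with how the tests call it; the default timeout and the 50 ms poll interval are assumptions, not the actual implementation.

# Hypothetical sketch of integration_utils.until_event_with_timeout.
import asyncio
import time
from typing import Callable


async def until_event_with_timeout(
    global_events,  # AsyncSQLiteEventStorage-like store with get_events_since()
    event_type: type,  # e.g. InstanceDeleted, TaskFailed, RunnerStatusUpdated
    multiplicity: int = 1,  # how many matching events must have been emitted
    condition: Callable[[object], bool] | None = None,
    timeout: float = 10.0,
) -> None:
    """Poll the event log until `multiplicity` events of `event_type` satisfy `condition`."""
    deadline = time.monotonic() + timeout
    while True:
        events = await global_events.get_events_since(0)
        matches = [
            x.event
            for x in events
            if isinstance(x.event, event_type)
            and (condition is None or condition(x.event))
        ]
        if len(matches) >= multiplicity:
            return
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"Expected {multiplicity} {event_type.__name__} events, saw {len(matches)}"
            )
        await asyncio.sleep(0.05)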