From a33787f5fda4d7a78ee0793adb448aa6065241e5 Mon Sep 17 00:00:00 2001 From: Matt Beton Date: Fri, 29 Aug 2025 08:07:36 -0700 Subject: [PATCH] Prompt length --- pyproject.toml | 9 + src/exo/engines/mlx/__init__.py | 31 +++ src/exo/engines/mlx/utils_mlx.py | 106 ++++---- src/exo/shared/models/model_cards.py | 13 + .../shared/types/worker/commands_runner.py | 9 + src/exo/shared/types/worker/communication.py | 160 ++++++++++++ src/exo/worker/main.py | 4 +- src/exo/worker/runner/communication.py | 102 -------- src/exo/worker/runner/runner.py | 245 ++++++++++++++++-- src/exo/worker/runner/runner_supervisor.py | 38 ++- src/exo/worker/runner/utils.py | 31 ++- .../test_integration/integration_utils.py | 10 + .../tests/test_integration/test_inference.py | 7 + .../test_integration/test_inference_sad.py | 5 +- .../test_integration/test_instantiation.py | 40 --- .../test_instantiation_sad.py | 2 +- .../worker/tests/test_supervisor/test_long.py | 169 ++++++++++++ 17 files changed, 753 insertions(+), 228 deletions(-) create mode 100644 src/exo/shared/types/worker/communication.py delete mode 100644 src/exo/worker/runner/communication.py create mode 100644 src/exo/worker/tests/test_supervisor/test_long.py diff --git a/pyproject.toml b/pyproject.toml index 52e708e2..ba64ebba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,15 @@ reportUnnecessaryTypeIgnoreComment = "error" pythonVersion = "3.13" pythonPlatform = "Darwin" +exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts"] +stubPath = "typings" + +[[tool.basedpyright.executionEnvironments]] +root = "src" + +[[tool.basedpyright.executionEnvironments]] +root = "." + ### # uv configuration ### diff --git a/src/exo/engines/mlx/__init__.py b/src/exo/engines/mlx/__init__.py index e69de29b..3672ffac 100644 --- a/src/exo/engines/mlx/__init__.py +++ b/src/exo/engines/mlx/__init__.py @@ -0,0 +1,31 @@ +from typing import Optional + +from mlx_lm.models.cache import KVCache + +import mlx.core as mx +import mlx.nn as nn # type: ignore + +# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is. +# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function + +class Model(nn.Module): + layers: list[nn.Module] + + def __call__(self, x: mx.array, cache: Optional[list[KVCache]]) -> mx.array: ... + + +class Detokenizer: + def reset(self) -> None: ... + def add_token(self, token: int) -> None: ... + def finalize(self) -> None: ... + + @property + def last_segment(self) -> str: ... + + +class TokenizerWrapper: + bos_token: Optional[str] + eos_token_ids: list[int] + detokenizer: Detokenizer + + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ... 
\ No newline at end of file diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/engines/mlx/utils_mlx.py index daf1636b..e8df5a8d 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/engines/mlx/utils_mlx.py @@ -4,27 +4,30 @@ import contextlib import os import resource from asyncio import AbstractEventLoop -from typing import Any, Callable +from typing import Any, Callable, Optional, cast -from mlx_lm.generate import stream_generate # type: ignore +from mlx_lm.models.cache import KVCache from mlx_lm.sample_utils import make_sampler -from mlx_lm.tokenizer_utils import TokenizerWrapper, load_tokenizer # type: ignore +from mlx_lm.tokenizer_utils import TokenizerWrapper as _TokenizerWrapper +from mlx_lm.tokenizer_utils import load_tokenizer # type: ignore from mlx_lm.utils import load_model # type: ignore from pydantic import RootModel import mlx.core as mx import mlx.nn as nn # pyright: ignore[reportMissingTypeStubs] -from exo.engines.mlx.auto_parallel import auto_parallel -from exo.shared.types.api import ChatCompletionMessage +from exo.engines.mlx import Model, TokenizerWrapper +from exo.engines.mlx.auto_parallel import IdentityLayer, auto_parallel from exo.shared.types.common import Host from exo.shared.types.tasks import ChatCompletionTaskParams +from exo.shared.types.worker.communication import runner_print from exo.shared.types.worker.shards import ShardMetadata from exo.worker.download.download_utils import build_model_path -from exo.worker.runner.communication import runner_print # Needed for 8 bit model resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) +mlx_rank: None | int = None +mlx_world_size: None | int = None def mx_barrier(): mx.eval( # type: ignore @@ -33,6 +36,18 @@ def mx_barrier(): ) ) +def broadcast_from_zero(value: int) -> int: + if mlx_rank is None: + return value + + if mlx_rank == 0: + a = mx.array([value], dtype=mx.int32) + else: + a = mx.array([0], dtype=mx.int32) + + m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu)) + mx.eval(m) # type: ignore + return int(m.item()) # type: ignore class HostList(RootModel[list[str]]): @classmethod @@ -78,6 +93,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: """ Initialize the MLX distributed (runs in thread pool) """ + global mlx_rank, mlx_world_size runner_print(f"Starting initialization for rank {rank}") # Setup distributed environment @@ -94,6 +110,8 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: os.environ["MLX_RING_VERBOSE"] = "1" group = mx.distributed.init(backend="ring", strict=True) + mlx_rank = group.rank() + mlx_world_size = group.rank() runner_print(f"Rank {rank} mlx distributed initialization complete") return group @@ -102,7 +120,7 @@ def mlx_distributed_init(rank: int, hosts: list[Host]) -> mx.distributed.Group: def initialize_mlx( model_shard_meta: ShardMetadata, hosts: list[Host], -) -> tuple[nn.Module, TokenizerWrapper, Callable[[mx.array], mx.array]]: +) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]: """ Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread. 
""" @@ -112,6 +130,7 @@ def initialize_mlx( sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7) model, tokenizer = shard_and_load(model_shard_meta) + model = cast(Model, model) return model, tokenizer, sampler @@ -123,18 +142,19 @@ def shard_and_load( runner_print(f"loading model from {model_path}") - model, _ = load_model(model_path, lazy=True, strict=False) # type: ignore + model, config = load_model(model_path, lazy=True, strict=False) # type: ignore + runner_print(f'{config=}') assert isinstance(model, nn.Module) tokenizer = load_tokenizer(model_path) - assert isinstance(tokenizer, TokenizerWrapper) + assert isinstance(tokenizer, _TokenizerWrapper) model = auto_parallel(model, model_shard_meta) mx.eval(model.parameters()) # type: ignore # Synchronize processes before generation to avoid timeout mx_barrier() - return model, tokenizer + return model, tokenizer # type: ignore async def apply_chat_template( @@ -179,47 +199,37 @@ async def apply_chat_template( return prompt +class NullKVCache(KVCache): + """ + A KVCache that pretends to exist but holds zero tokens. + It satisfies .state/.meta_state and never allocates real keys/values. + """ + def __init__(self, dtype: mx.Dtype = mx.float16): + super().__init__() + # zero-length K/V so shapes/dtypes are defined but empty + self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType] + self.values = mx.zeros((1, 1, 0, 1), dtype=dtype) # pyright: ignore[reportUnknownMemberType] + self.offset = 0 -async def warmup_inference( - mlx_executor: concurrent.futures.ThreadPoolExecutor, - model: nn.Module, - tokenizer: TokenizerWrapper, - sampler: Callable[[mx.array], mx.array], -) -> int: - loop = asyncio.get_running_loop() + @property + def state(self) -> tuple[mx.array, mx.array]: + # matches what mx.save_safetensors / mx.eval expect + return self.keys, self.values - warmup_prompt = await apply_chat_template( - mlx_executor=mlx_executor, - tokenizer=tokenizer, - chat_task_data=ChatCompletionTaskParams( - model="warmup", - messages=[ - ChatCompletionMessage( - role="user", - content="Prompt to warm up the inference engine. 
Repeat this.", - ) - ], - ), - ) - - tokens_generated = 0 - - def _generate_warmup(): - nonlocal tokens_generated - for _ in stream_generate( - model=model, - tokenizer=tokenizer, - prompt=warmup_prompt, - max_tokens=50, - sampler=sampler, - ): - tokens_generated += 1 - - await loop.run_in_executor(mlx_executor, _generate_warmup) - mx_barrier() - - return tokens_generated + @state.setter + def state(self, v: tuple[mx.array, mx.array]) -> None: + raise NotImplementedError('We should not be setting a NullKVCache.') +async def make_kv_cache( + model: Model, + max_kv_size: Optional[int] = None, +) -> list[KVCache]: + assert hasattr(model, 'layers') + + return [ + NullKVCache() if isinstance(layer, IdentityLayer) else KVCache() + for layer in model.layers + ] def mlx_force_oom(size: int = 40000) -> None: """ diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index ff0669ec..4b47559a 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -55,6 +55,19 @@ MODEL_CARDS: dict[str, ModelCard] = { n_layers=61, ), ), + "deepseek-v3.1:4bit": ModelCard( + short_id="deepseek-v3.1:4bit", + model_id="mlx-community/DeepSeek-V3.1-4bit", + name="DeepSeek V3.1 (4-bit)", + description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""", + tags=[], + metadata=ModelMetadata( + model_id="mlx-community/DeepSeek-V3.1-4bit", + pretty_name="DeepSeek V3.1 (4-bit)", + storage_size_kilobytes=754706307 // 2, # TODO !!!!! + n_layers=61, + ), + ), # deepseek r1 "deepseek-r1-0528:4bit": ModelCard( short_id="deepseek-r1-0528:4bit", diff --git a/src/exo/shared/types/worker/commands_runner.py b/src/exo/shared/types/worker/commands_runner.py index be3b27c5..512e81cc 100644 --- a/src/exo/shared/types/worker/commands_runner.py +++ b/src/exo/shared/types/worker/commands_runner.py @@ -52,6 +52,7 @@ RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage class RunnerResponseType(str, Enum): InitializedResponse = "initialized_response" + TokenizedResponse = "tokenized_response" GenerationResponse = "generation_response" FinishedResponse = "finished_response" PrintResponse = "print_response" @@ -72,6 +73,13 @@ class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedRespo time_taken: float +class TokenizedResponse(BaseRunnerResponse[RunnerResponseType.TokenizedResponse]): + type: Literal[RunnerResponseType.TokenizedResponse] = Field( + default=RunnerResponseType.TokenizedResponse, frozen=True + ) + prompt_tokens: int + + class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]): type: Literal[RunnerResponseType.GenerationResponse] = Field( default=RunnerResponseType.GenerationResponse, frozen=True @@ -106,6 +114,7 @@ class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]): RunnerResponse = Annotated[ InitializedResponse + | TokenizedResponse | GenerationResponse | PrintResponse | FinishedResponse diff --git a/src/exo/shared/types/worker/communication.py b/src/exo/shared/types/worker/communication.py new file mode 100644 index 00000000..a1ea6c4e --- /dev/null +++ b/src/exo/shared/types/worker/communication.py @@ -0,0 +1,160 @@ +import asyncio +import json +import struct +import sys +import traceback +from typing import Any, BinaryIO, Dict, Tuple, Union, cast + +from loguru import logger + +from exo.shared.types.worker.commands_runner import ( + ErrorResponse, + PrintResponse, + RunnerMessage, + RunnerMessageTypeAdapter, + RunnerResponse, + 
RunnerResponseType, + RunnerResponseTypeAdapter, +) + +### Utils - SAFE LENGTH READ/WRITE + +MAGIC = b"EXO1" +HDR_FMT = "!I" # 4-byte big-endian length + + +async def write_frame(stream: Union[asyncio.StreamWriter, Any], obj: Union[Dict[str, Any], bytes]) -> None: + """Write a length-prefixed frame to a stream.""" + payload = obj if isinstance(obj, bytes) else json.dumps(obj).encode("utf-8") + header = MAGIC + struct.pack(HDR_FMT, len(payload)) + stream.write(header + payload) + if hasattr(stream, 'drain'): + await stream.drain() + + +async def read_frame(stream: Union[asyncio.StreamReader, Any]) -> Dict[str, Any]: + """Read a length-prefixed frame from a stream.""" + # Read 8 bytes: 4-byte magic + 4-byte length + header: bytes = await stream.readexactly(8) + if header[:4] != MAGIC: + # Fallback to legacy newline mode for backward compatibility + # Reconstruct the partial line and read the rest + remaining: bytes = await stream.readline() + line = header + remaining + return cast(Dict[str, Any], json.loads(line.strip().decode('utf-8'))) + + (length,) = cast(Tuple[int], struct.unpack(HDR_FMT, header[4:])) + data: bytes = await stream.readexactly(length) + return cast(Dict[str, Any], json.loads(data.decode('utf-8'))) + + +def write_frame_sync(stream: BinaryIO, obj: Union[Dict[str, Any], bytes]) -> None: + """Synchronous version of write_frame for use in runner.""" + payload = obj if isinstance(obj, bytes) else json.dumps(obj).encode("utf-8") + header = MAGIC + struct.pack(HDR_FMT, len(payload)) + stream.write(header + payload) + stream.flush() + + +def read_frame_sync(stream: BinaryIO) -> Dict[str, Any]: + """Synchronous version of read_frame for use in runner.""" + # Read 8 bytes: 4-byte magic + 4-byte length + header: bytes = stream.read(8) + if not header or len(header) < 8: + raise EOFError("No more data to read") + + if header[:4] != MAGIC: + # Fallback to legacy newline mode for backward compatibility + # Reconstruct the partial line and read the rest + remaining: bytes = stream.readline() + if not remaining: + raise EOFError("No more data to read") + line = header + remaining + return cast(Dict[str, Any], json.loads(line.strip().decode('utf-8'))) + + (length,) = cast(Tuple[int], struct.unpack(HDR_FMT, header[4:])) + data: bytes = stream.read(length) + if len(data) < length: + raise EOFError(f"Expected {length} bytes, got {len(data)}") + return cast(Dict[str, Any], json.loads(data.decode('utf-8'))) + + + +### Utils - MESSAGE TO RUNNER + +async def supervisor_write_message( + proc: asyncio.subprocess.Process, message: RunnerMessage +) -> None: + assert proc.stdin is not None, ( + "proc.stdin should not be None when created with stdin=PIPE" + ) + + # Use model_dump_json to get proper JSON encoding for Pydantic types like IPv4Address + await write_frame(proc.stdin, message.model_dump_json().encode('utf-8')) + + +async def runner_read_message() -> RunnerMessage: + loop = asyncio.get_running_loop() + + # Use executor to avoid blocking the event loop + data: Dict[str, Any] = await loop.run_in_executor(None, read_frame_sync, sys.stdin.buffer) + + try: + return RunnerMessageTypeAdapter.validate_python(data) + except Exception as e: + raise ValueError(f"Error validating message: {data}") from e + + +### Utils - RESPONSE FROM RUNNER + +def runner_write_response(obj: RunnerResponse) -> None: + try: + # Use model_dump_json to get proper JSON encoding + write_frame_sync(sys.stdout.buffer, obj.model_dump_json().encode('utf-8')) + except BrokenPipeError: + # Supervisor has closed the pipe, silently 
exit + sys.exit(0) + + +async def supervisor_read_response( + proc: asyncio.subprocess.Process, +) -> RunnerResponse: + assert proc.stdout is not None, ( + "proc.stdout should not be None when created with stdout=PIPE" + ) + + data: Dict[str, Any] + try: + data = await read_frame(proc.stdout) + return RunnerResponseTypeAdapter.validate_python(data) + except EOFError: + raise EOFError('No more data to read when reading response from runner.') from None + except Exception as err: + raise ValueError(f"Error validating response: {err}") from err + + +### Utils - Runner Prints + + +def runner_print(text: str) -> None: + obj = PrintResponse( + type=RunnerResponseType.PrintResponse, + text=text, + ) + + runner_write_response(obj) + + +def runner_write_error(error: Exception) -> None: + # Skip writing error if it's a BrokenPipeError - supervisor is already gone + if isinstance(error, BrokenPipeError): + sys.exit(0) + + error_response: ErrorResponse = ErrorResponse( + type=RunnerResponseType.ErrorResponse, + error_type=type(error).__name__, + error_message=str(error), + traceback=traceback.format_exc(), + ) + runner_write_response(error_response) + logger.opt(exception=error).exception("Critical Runner error") diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index a44280a1..edb58f2c 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -47,8 +47,8 @@ async def run(worker: Worker): # run the op, synchronously blocking for now if op is not None: - logger.info(f"Executing op {op}") - logger.bind(user_facing=True).debug(f"Worker executing op: {op}") + logger.info(f"Executing op {str(op)[:500]}") + logger.bind(user_facing=True).debug(f"Worker executing op: {str(op)[:500]}") try: async for event in worker.execute_op(op): await worker.event_publisher(event) diff --git a/src/exo/worker/runner/communication.py b/src/exo/worker/runner/communication.py deleted file mode 100644 index d02ffb02..00000000 --- a/src/exo/worker/runner/communication.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -import sys -import traceback - -from loguru import logger - -from exo.shared.types.worker.commands_runner import ( - ErrorResponse, - PrintResponse, - RunnerMessage, - RunnerMessageTypeAdapter, - RunnerResponse, - RunnerResponseType, - RunnerResponseTypeAdapter, -) - -### Utils - MESSAGE TO RUNNER - - -async def supervisor_write_message( - proc: asyncio.subprocess.Process, message: RunnerMessage -) -> None: - assert proc.stdin is not None, ( - "proc.stdin should not be None when created with stdin=PIPE" - ) - - encoded: bytes = message.model_dump_json().encode("utf-8") + b"\n" - proc.stdin.write(encoded) - await proc.stdin.drain() - - -async def runner_read_message() -> RunnerMessage: - loop = asyncio.get_running_loop() - - line: bytes = await loop.run_in_executor(None, sys.stdin.buffer.readline) - if not line: # This seems to be what triggers when we don't clean up the runner neatly and leave the process dangling. 
- raise EOFError("No more data to read when reading runner message") - line = line.strip() - - try: - return RunnerMessageTypeAdapter.validate_json(line) - except Exception as e: - raise ValueError(f"Error validating message: {line}") from e - - -### Utils - RESPONSE FROM RUNNER - - -def runner_write_response(obj: RunnerResponse) -> None: - try: - encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n" - _ = sys.stdout.buffer.write(encoded) - _ = sys.stdout.buffer.flush() - except BrokenPipeError: - # Supervisor has closed the pipe, silently exit - sys.exit(0) - - -async def supervisor_read_response( - proc: asyncio.subprocess.Process, -) -> RunnerResponse: - assert proc.stdout is not None, ( - "proc.stdout should not be None when created with stdout=PIPE" - ) - # TODO: We could put a timeout on this if we decide to send heartbeats from the runner. - # This lets us handle cases where the process dies at some point not during an inference. - line_bytes: bytes = await proc.stdout.readline() - if not line_bytes: - raise EOFError('No more data to read when reading response from runner.') - line: str = line_bytes.decode("utf-8").strip() - - try: - return RunnerResponseTypeAdapter.validate_json(line) - except Exception as err: - raise ValueError(f"Error validating response: {line}") from err - - -### Utils - Runner Prints - - -def runner_print(text: str) -> None: - obj = PrintResponse( - type=RunnerResponseType.PrintResponse, - text=text, - ) - - runner_write_response(obj) - - -def runner_write_error(error: Exception) -> None: - # Skip writing error if it's a BrokenPipeError - supervisor is already gone - if isinstance(error, BrokenPipeError): - sys.exit(0) - - error_response: ErrorResponse = ErrorResponse( - type=RunnerResponseType.ErrorResponse, - error_type=type(error).__name__, - error_message=str(error), - traceback=traceback.format_exc(), - ) - runner_write_response(error_response) - logger.opt(exception=error).exception("Critical Runner error") diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index 287f1e2a..9d118512 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -3,21 +3,25 @@ import concurrent.futures import time from collections.abc import AsyncGenerator from functools import partial -from typing import Callable, cast +from typing import Callable, Generator, Optional, Tuple import mlx.core as mx -import mlx.nn as nn # pyright: ignore [reportMissingTypeStubs] -from mlx_lm.generate import stream_generate # type: ignore -from mlx_lm.tokenizer_utils import TokenizerWrapper +from mlx.core import array +from mlx_lm.generate import stream_generate as mlx_stream_generate +from mlx_lm.models import cache +from mlx_lm.models.cache import KVCache +from exo.engines.mlx import Model, TokenizerWrapper from exo.engines.mlx.utils_mlx import ( apply_chat_template, + broadcast_from_zero, initialize_mlx, + make_kv_cache, mlx_force_oom, mlx_setup, - warmup_inference, + mx_barrier, ) -from exo.shared.openai_compat import FinishReason +from exo.shared.types.api import ChatCompletionMessage from exo.shared.types.tasks import ChatCompletionTaskParams from exo.shared.types.worker.commands_runner import ( ChatTaskMessage, @@ -25,22 +29,216 @@ from exo.shared.types.worker.commands_runner import ( FinishedResponse, GenerationResponse, InitializedResponse, - RunnerMessage, SetupMessage, + TokenizedResponse, ) -from exo.shared.utils import ensure_type -from exo.worker.runner.communication import ( +from exo.shared.types.worker.communication 
import ( runner_print, runner_read_message, runner_write_error, runner_write_response, ) +from exo.shared.utils import ensure_type from exo.worker.runner.utils import get_weights_size_kb +generation_stream = mx.new_stream(mx.default_device()) + +def generate_step( + prompt: mx.array, + model: Model, + *, + max_tokens: int = 256, + sampler: Callable[[mx.array], mx.array], + max_kv_size: Optional[int] = None, + prompt_cache: Optional[list[KVCache]] = None, + prefill_step_size: int = 2048, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (Model): The model to use for generation. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + max_kv_size (int, optional): Maximum size of the key-value cache. Old + entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. + prefill_step_size (int): Step size for processing the prompt. + + Yields: + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. + """ + tokens = None + + # Create the KV cache for generation + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) + + def _step(input_tokens: mx.array): + nonlocal tokens + + with mx.stream(generation_stream): + logits = model( + input_tokens[None], + cache=prompt_cache, + ) + + logits = logits[:, -1, :] + + logprobs = logits - mx.logsumexp(logits, keepdims=True) # pyright: ignore[reportUnknownMemberType] + sampled = sampler(logprobs) + return sampled, logprobs.squeeze(0) + + with mx.stream(generation_stream): + total_prompt_tokens = len(prompt) + prompt_processed_tokens = 0 + + while total_prompt_tokens - prompt_processed_tokens > prefill_step_size: + runner_print(f'Prefilling {min(prefill_step_size, len(prompt))} tokens. Remaining tokens: {len(prompt)}. 
Peak memory: {mx.get_peak_memory() // 2**30} GB') + logits = model( + prompt[:prefill_step_size][None], + cache=prompt_cache + ) + + start_time = time.time() + mx.eval([c.state for c in prompt_cache] + [logits]) # type: ignore + eval_time = time.time() - start_time + prompt_processed_tokens += prefill_step_size + + prompt = prompt[prefill_step_size:] + + mx.clear_cache() + if eval_time > 7.0: + prefill_step_size = prefill_step_size // 2 + prefill_step_size = broadcast_from_zero(prefill_step_size) + prefill_step_size = max(1, prefill_step_size) + + + runner_print('finished prefil.') + y, logprobs = _step(input_tokens=prompt) + + mx.async_eval(y, logprobs) # type: ignore + n = 0 + next_y: array | None = None + next_logprobs: array | None = None + while True: + if n != max_tokens and n > 0: # Only call _step after first iteration + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) # type: ignore + if n == 0: + mx.eval(y) # type: ignore + if n == max_tokens: + break + yield y, logprobs # y is always defined here, no need for cast + if n % 256 == 0: + mx.clear_cache() + if next_y is not None and next_logprobs is not None: + y, logprobs = next_y, next_logprobs + n += 1 + + + +def stream_generate( + model: Model, + tokenizer: TokenizerWrapper, + prompt: str, + max_tokens: int, + sampler: Callable[[mx.array], mx.array], + prompt_cache: Optional[list[KVCache]] = None, + prefill_step_size: int = 2048, +) -> Generator[GenerationResponse, None, None]: + + # Try to infer if special tokens are needed + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + prompt_array: mx.array = mx.array(tokenizer.encode(prompt, add_special_tokens=add_special_tokens)) + runner_write_response(TokenizedResponse(prompt_tokens=len(prompt_array))) + + detokenizer = tokenizer.detokenizer + + token_generator: Generator[Tuple[array, array], None, None] = generate_step( + prompt_array, + model, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=prompt_cache, + prefill_step_size=prefill_step_size, + ) + + token = None + detokenizer.reset() + for token, _ in token_generator: + if token in tokenizer.eos_token_ids: + break + + detokenizer.add_token(int(token)) + + # TODO: We could put more metrics on this GenerationResponse if we wish + yield GenerationResponse( + text=detokenizer.last_segment, + token=int(token), + finish_reason=None, + ) + + assert token is not None + detokenizer.finalize() + yield GenerationResponse( + text=detokenizer.last_segment, + token=int(token), + finish_reason="stop" if token in tokenizer.eos_token_ids else "length", + ) + +async def warmup_inference( + mlx_executor: concurrent.futures.ThreadPoolExecutor, + model: Model, + tokenizer: TokenizerWrapper, + sampler: Callable[[mx.array], mx.array], +) -> int: + loop = asyncio.get_running_loop() + + warmup_prompt = await apply_chat_template( + mlx_executor=mlx_executor, + tokenizer=tokenizer, + chat_task_data=ChatCompletionTaskParams( + model="warmup", + messages=[ + ChatCompletionMessage( + role="user", + content="Prompt to warm up the inference engine. 
Repeat this.", + ) + ], + ), + ) + + tokens_generated = 0 + + def _generate_warmup(): + nonlocal tokens_generated + for _ in mlx_stream_generate( + model=model, + tokenizer=tokenizer, + prompt=warmup_prompt, + max_tokens=50, + sampler=sampler, + ): + tokens_generated += 1 + + await loop.run_in_executor(mlx_executor, _generate_warmup) + mx_barrier() + + return tokens_generated async def _mlx_generate( mlx_executor: concurrent.futures.ThreadPoolExecutor, - model: nn.Module, + model: Model, tokenizer: TokenizerWrapper, sampler: Callable[[mx.array], mx.array], task: ChatCompletionTaskParams, @@ -49,7 +247,7 @@ async def _mlx_generate( queue: asyncio.Queue[GenerationResponse | Exception | object] = asyncio.Queue() sentinel = object() - def _generate_tokens(prompt: str, max_tokens: int) -> None: + def _generate_tokens(prompt: str, max_tokens: int, cache: list[KVCache]) -> None: try: for generation_response in stream_generate( model=model, @@ -57,15 +255,10 @@ async def _mlx_generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, + prompt_cache=cache, + prefill_step_size=1024, ): - response = GenerationResponse( - text=generation_response.text, - token=generation_response.token, - finish_reason=cast( - FinishReason | None, generation_response.finish_reason - ), # has to be considered as a FinishReason instead of a str. - ) - _ = loop.call_soon_threadsafe(queue.put_nowait, response) + _ = loop.call_soon_threadsafe(queue.put_nowait, generation_response) except Exception as e: _ = loop.call_soon_threadsafe(queue.put_nowait, e) finally: @@ -80,8 +273,16 @@ async def _mlx_generate( chat_task_data=task, ) + cache_future = loop.run_in_executor( + mlx_executor, + lambda: asyncio.run(make_kv_cache( + model=model, + )) + ) + cache = await cache_future + max_tokens = task.max_tokens or 1000 - generation_fn = partial(_generate_tokens, prompt, max_tokens) + generation_fn = partial(_generate_tokens, prompt, max_tokens, cache) future = loop.run_in_executor(mlx_executor, generation_fn) @@ -142,10 +343,10 @@ async def main(): ) while True: - message: RunnerMessage = await runner_read_message() + message = await runner_read_message() match message: case ChatTaskMessage(task_data=task): - runner_print(f"received chat request: {task}") + runner_print(f"received chat request: {str(task)[:500]}") # Ensure we have a chat-completion task subtype # TODO: this is a hack, why are we only looking at the first message? 
should have a tokenizer prompt = task.messages[0] diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 665f00c4..20a5fc09 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -21,13 +21,14 @@ from exo.shared.types.worker.commands_runner import ( RunnerMessage, RunnerResponse, SetupMessage, + TokenizedResponse, ) from exo.shared.types.worker.common import RunnerError -from exo.shared.types.worker.shards import ShardMetadata -from exo.worker.runner.communication import ( +from exo.shared.types.worker.communication import ( supervisor_read_response, supervisor_write_message, ) +from exo.shared.types.worker.shards import ShardMetadata from exo.worker.runner.utils import ( get_init_timeout, get_prefil_timeout, @@ -136,6 +137,7 @@ class RunnerSupervisor: if queue_task in done: response = await queue_task if isinstance(response, ErrorResponse): + await self.astop() raise RunnerError( response.error_type, response.error_message, @@ -178,12 +180,38 @@ class RunnerSupervisor: ), ) - # This is simpler for now: we say 'request started' as soon as we've told runner to start, without waiting for an ack. - # If we need more reliability, the runner can have a new 'ready' message type. + while True: + try: + response = await self._read_with_error_check(5.0) + except asyncio.TimeoutError as e: + logger.bind(user_facing=True).error( + "Generation timed out during tokenization" + ) + raise e + except asyncio.LimitOverrunError as e: + raise RunnerError( + "IPCMessageTooLarge", + "The serialized prompt/response exceeded the IPC line limit. Switch to length-prefixed framing or reduce prompt size.", + "" + ) from e + + + match response: + case TokenizedResponse(): + prompt_tokens = response.prompt_tokens + break + case ErrorResponse(): + await self.astop() + raise RunnerError( + response.error_type, response.error_message, response.traceback + ) + case _: + raise ValueError(f"Unexpected response type found: {response}") + if request_started_callback is not None: await request_started_callback() - prefil_timeout = get_prefil_timeout(self.model_shard_meta) + prefil_timeout = get_prefil_timeout(self.model_shard_meta, prompt_tokens=prompt_tokens) token_timeout = get_token_generate_timeout(self.model_shard_meta) timeout = prefil_timeout logger.bind(user_facing=True).info( diff --git a/src/exo/worker/runner/utils.py b/src/exo/worker/runner/utils.py index e3ddae62..1d68f377 100644 --- a/src/exo/worker/runner/utils.py +++ b/src/exo/worker/runner/utils.py @@ -67,14 +67,33 @@ def get_init_timeout(model_shard_meta: ShardMetadata) -> float: return weights_size_kb / kbps_read + 2.0 -def get_prefil_timeout(model_shard_meta: ShardMetadata) -> float: - return 30.0 # TODO: Proper prefil timeout calculation, but this requires knowing the number of tokens in the prompt. - weights_size_gb = get_weights_size_kb(model_shard_meta) / (1024 * 1024) - tokens = 1000 # constant for now - the prompt is only tokenized in the device... 
- prompt_gflops = tokens * weights_size_gb * 2 +def _prefill_flops_for_shard(model_shard_meta: ShardMetadata, s: int) -> float: + p = get_weights_size_kb(model_shard_meta) * 1024 + flops = 2.0 * p * s # parameter-dependent GEMMs + # flops += _attention_flops(meta, S) # optional S^2 term + return flops + +def get_prefil_timeout( + model_shard_meta: ShardMetadata, + prompt_tokens: int, + *, + effective_tflops: float = LB_TFLOPS, + safety_mult: float = 1.6, + base_pad_s: float = 5.0 +) -> float: + """ + Returns a conservative timeout (seconds) for the prefill stage. + """ + total_flops = _prefill_flops_for_shard(model_shard_meta, prompt_tokens) + + # Convert to seconds using sustained throughput + time_seconds = total_flops / (effective_tflops * 1e12) + + # Prefill across pipeline stages is largely sequential; summing FLOPs already accounts for it. + # Add a base pad (launch/IO) and a safety multiplier for variance. + return base_pad_s + safety_mult * time_seconds - return LB_TFLOPS / (1024 * prompt_gflops) * 3 + 10.0 def get_token_generate_timeout(model_shard_meta: ShardMetadata) -> float: diff --git a/src/exo/worker/tests/test_integration/integration_utils.py b/src/exo/worker/tests/test_integration/integration_utils.py index 9d088a70..c0fea3ed 100644 --- a/src/exo/worker/tests/test_integration/integration_utils.py +++ b/src/exo/worker/tests/test_integration/integration_utils.py @@ -74,9 +74,12 @@ async def until_event_with_timeout( event_type: type[T], multiplicity: int = 1, condition: Callable[[T], bool] = lambda x: True, + timeout: float = 30.0, ) -> None: idx = await global_events.get_last_idx() times_seen = 0 + start_time = asyncio.get_event_loop().time() + while True: events = await global_events.get_events_since(idx) if events: @@ -89,4 +92,11 @@ async def until_event_with_timeout( return idx = events[-1].idx_in_log + current_time = asyncio.get_event_loop().time() + if current_time - start_time > timeout: + raise asyncio.TimeoutError( + f"Timeout waiting for {multiplicity} events of type {event_type.__name__} " + f"(found {times_seen} in {timeout}s)" + ) + await asyncio.sleep(0.01) diff --git a/src/exo/worker/tests/test_integration/test_inference.py b/src/exo/worker/tests/test_integration/test_inference.py index 3d430f41..23399b6d 100644 --- a/src/exo/worker/tests/test_integration/test_inference.py +++ b/src/exo/worker/tests/test_integration/test_inference.py @@ -2,6 +2,8 @@ import asyncio from logging import Logger from typing import Awaitable, Callable +import pytest + from exo.shared.db.sqlite.connector import AsyncSQLiteEventStorage from exo.shared.db.sqlite.event_log_manager import EventLogConfig, EventLogManager from exo.shared.logging import logger_test_install @@ -44,6 +46,11 @@ from exo.worker.tests.test_integration.integration_utils import ( from exo.worker.worker import Worker +@pytest.fixture +def user_message(): + """Override this fixture in tests to customize the message""" + return "What's the capital of Japan?" 
+ async def test_runner_inference( worker_running: Callable[ [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] diff --git a/src/exo/worker/tests/test_integration/test_inference_sad.py b/src/exo/worker/tests/test_integration/test_inference_sad.py index d5aa4688..e42c92a7 100644 --- a/src/exo/worker/tests/test_integration/test_inference_sad.py +++ b/src/exo/worker/tests/test_integration/test_inference_sad.py @@ -78,7 +78,7 @@ async def test_stream_response_failed_always( origin=MASTER_NODE_ID, ) - await until_event_with_timeout(global_events, InstanceDeleted) + await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0) events = await global_events.get_events_since(0) @@ -168,6 +168,7 @@ async def test_stream_response_failed_once( 1, condition=lambda x: isinstance(x.chunk, TokenChunk) and x.chunk.finish_reason is not None, + timeout=30.0, ) # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say @@ -256,7 +257,7 @@ async def test_stream_response_timeout( origin=MASTER_NODE_ID, ) - await until_event_with_timeout(global_events, TaskFailed, multiplicity=3) + await until_event_with_timeout(global_events, TaskFailed, multiplicity=3, timeout=30.0) events = await global_events.get_events_since(0) print(events) diff --git a/src/exo/worker/tests/test_integration/test_instantiation.py b/src/exo/worker/tests/test_integration/test_instantiation.py index dc0773b2..8671777e 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation.py +++ b/src/exo/worker/tests/test_integration/test_instantiation.py @@ -1,4 +1,3 @@ -import asyncio from typing import Awaitable, Callable # TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py @@ -29,45 +28,6 @@ from exo.worker.tests.constants import ( from exo.worker.tests.test_integration.integration_utils import until_event_with_timeout -async def test_runner_spinup_exception( - worker_running: Callable[ - [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] - ], - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], -): - _, global_events = await worker_running(NODE_A) - - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.ACTIVE - instance_value.shard_assignments.runner_to_shard[ - RUNNER_1_ID - ].immediate_exception = True - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await asyncio.sleep(5.0) - - # Ensure the correct events have been emitted - events = await global_events.get_events_since(0) - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - and x.event.runner_status.error_message is not None - and "fake exception" in x.event.runner_status.error_message.lower() - ] - ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - - async def test_runner_spinup_timeout( worker_running: Callable[ [NodeId], Awaitable[tuple[Worker, AsyncSQLiteEventStorage]] diff --git a/src/exo/worker/tests/test_integration/test_instantiation_sad.py b/src/exo/worker/tests/test_integration/test_instantiation_sad.py index beb73acf..c4329162 100644 --- a/src/exo/worker/tests/test_integration/test_instantiation_sad.py +++ b/src/exo/worker/tests/test_integration/test_instantiation_sad.py @@ -47,7 +47,7 @@ async def test_runner_spinup_exception( [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID ) - await 
asyncio.sleep(5.0) + await asyncio.sleep(10.0) # Ensure the correct events have been emitted events = await global_events.get_events_since(0) diff --git a/src/exo/worker/tests/test_supervisor/test_long.py b/src/exo/worker/tests/test_supervisor/test_long.py new file mode 100644 index 00000000..51381ba5 --- /dev/null +++ b/src/exo/worker/tests/test_supervisor/test_long.py @@ -0,0 +1,169 @@ +import asyncio +from logging import Logger +from typing import Callable + +import pytest + +from exo.shared.logging import logger_test_install +from exo.shared.models.model_cards import MODEL_CARDS +from exo.shared.openai_compat import FinishReason +from exo.shared.types.common import Host +from exo.shared.types.events.chunks import TokenChunk +from exo.shared.types.tasks import ( + Task, + TaskId, +) +from exo.shared.types.worker.common import InstanceId +from exo.shared.types.worker.shards import PipelineShardMetadata +from exo.worker.runner.runner_supervisor import RunnerSupervisor + + +@pytest.fixture +def user_message(): + """Override the default message to ask about France's capital""" + return "What is the capital of France?" + +@pytest.fixture +def lorem_ipsum() -> str: + return """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Phasellus rhoncus felis in velit tempus tristique. Nullam ipsum lectus, tristique a eros quis, ullamcorper accumsan lorem. Aliquam ut auctor elit, finibus porttitor neque. In cursus augue facilisis ante ullamcorper, at sollicitudin quam aliquam. Etiam ac lacinia lacus, et aliquet nunc. Phasellus nisi ex, feugiat quis dolor non, mollis consequat nulla. Suspendisse gravida, sem non lobortis viverra, turpis lacus elementum orci, in tristique augue tortor nec mauris. Curabitur aliquet lorem in rhoncus mollis. Aliquam pulvinar elit odio, ac feugiat magna luctus nec. Pellentesque non risus egestas, pellentesque arcu tincidunt, gravida risus. Etiam ut lorem ac lorem pharetra efficitur. Donec augue arcu, varius nec lorem vitae, suscipit semper tellus. Aliquam dignissim quis augue id fermentum. Proin aliquet pellentesque est, eget tincidunt odio ullamcorper vel. Suspendisse potenti. +Aenean imperdiet justo sit amet erat aliquet tristique. Sed tempus, turpis a cursus lobortis, ante sem imperdiet est, eu dapibus sapien velit eget elit. Donec feugiat sed risus sed scelerisque. Donec posuere tempor orci, sit amet pellentesque est efficitur non. Vivamus sodales pretium purus, sed rutrum enim auctor ut. Cras pharetra vitae libero et hendrerit. Sed nec tempus odio. Proin blandit facilisis scelerisque. Nulla in mattis mi. Etiam bibendum efficitur aliquam. Proin ut risus aliquet, rhoncus lectus non, rhoncus arcu. Nam nibh felis, ultrices a elit sed, ultricies sollicitudin tellus. Interdum et malesuada fames ac ante ipsum primis in faucibus. Maecenas faucibus magna ut purus imperdiet faucibus. Nam fermentum nulla fermentum magna aliquam, vel lacinia neque euismod. Donec tincidunt sed neque non facilisis. +Proin id lorem cursus, vehicula ante non, lacinia metus. Nam egestas dui a iaculis convallis. Ut suscipit justo est, nec pharetra ante accumsan ac. Pellentesque nec nisi ipsum. Duis non arcu neque. Curabitur non luctus purus. Phasellus pulvinar commodo lacus sit amet auctor. Ut ut mattis metus, eu auctor arcu. Etiam a suscipit est. Morbi orci mauris, suscipit tempus fermentum vel, luctus faucibus lectus. Aliquam a euismod arcu. Suspendisse porttitor eget libero vitae laoreet. +Fusce congue lorem mi, a mollis felis efficitur quis. 
Quisque lobortis scelerisque arcu, a varius sapien. Nulla eget orci non urna imperdiet tincidunt. Nunc mi massa, consectetur id lorem consectetur, molestie dignissim sem. Suspendisse et augue magna. Mauris id tempus velit, cursus suscipit tortor. Duis non mi non nisi fringilla maximus in et erat. +Proin consequat sapien eget tellus aliquam ultrices. Nunc hendrerit semper massa, pulvinar sodales ipsum condimentum eu. Proin vel ligula venenatis, lobortis lectus eu, vehicula justo. Mauris eu arcu at orci vehicula feugiat non eu metus. Duis ut vestibulum quam. Maecenas dolor elit, egestas ut purus sit amet, convallis lobortis massa. Ut volutpat augue ac ante consectetur dignissim. Maecenas vitae felis elementum, semper augue eu, auctor dolor. Ut pulvinar convallis tortor non volutpat. Curabitur vulputate sem sodales sapien pretium ultrices. Sed luctus libero vitae urna eleifend tincidunt. Proin pulvinar imperdiet cursus. Suspendisse ullamcorper laoreet leo dapibus tincidunt. Pellentesque molestie elementum felis. +Integer vitae congue nulla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vestibulum elit velit, malesuada quis ipsum et, imperdiet varius velit. Nam tristique viverra maximus. Curabitur eget semper lectus. Sed vitae lorem sit amet mi lacinia posuere ac a risus. Pellentesque et magna nisl. In hac habitasse platea dictumst. Aenean suscipit, nibh vitae sollicitudin commodo, risus mi commodo neque, nec venenatis velit augue sed massa. Nam tempus, arcu id eleifend auctor, est dui viverra odio, vel convallis arcu dolor id quam. Ut malesuada ligula vel interdum eleifend. In posuere ultrices tincidunt. Sed non enim sit amet lectus sagittis mattis eu at sapien. Pellentesque eu urna mollis, vehicula dolor eget, lobortis nisl. Suspendisse ex nisi, iaculis non sapien ac, fringilla rutrum dolor. Quisque pretium mauris nec ante gravida, sed laoreet neque viverra. +Donec mattis orci sit amet tincidunt maximus. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Curabitur tristique venenatis lectus, vel pulvinar sem. Sed vel dolor lacinia, aliquet nisi ac, bibendum libero. Nullam vulputate euismod augue ac imperdiet. Proin at fermentum sapien. Nam et fringilla lorem. Aenean sed lacus sed tellus sodales mattis ut rutrum ex. Nulla ligula diam, interdum quis faucibus sit amet, laoreet vel massa. Fusce mauris massa, tempor quis tempus nec, dictum a ligula. Ut at dapibus sapien. Nullam sem lorem, sollicitudin non dui a, consequat molestie mauris. Quisque sem nulla, vehicula nec vulputate ac, viverra in massa. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur pretium venenatis nisi non bibendum. Nam vitae ligula auctor, rutrum lectus eget, feugiat augue. +Ut nunc risus, vehicula at metus non, consequat suscipit risus. Mauris eget sem in neque tincidunt iaculis. Pellentesque lacus leo, molestie ut pharetra sit amet, porta nec neque. Aliquam eu bibendum odio. Proin tempus bibendum ornare. Morbi non risus vitae ante tempor porta quis sed augue. Nullam hendrerit nulla in eleifend tincidunt. Integer suscipit ligula at nunc blandit vehicula. Nam porttitor leo in turpis suscipit malesuada. Etiam sodales nunc nisi, pharetra malesuada nibh varius in. Cras quis pellentesque augue, vitae convallis velit. In et dui lorem. Integer semper eros eget augue posuere, ac elementum tellus convallis. Praesent blandit tempus ultrices. Suspendisse nec dui vitae neque varius eleifend. 
Sed pretium metus leo, id viverra tellus scelerisque in. +Aenean sodales urna vitae lobortis cursus. Sed vitae pellentesque erat, fermentum pellentesque urna. Suspendisse potenti. Sed porttitor placerat turpis non vestibulum. Duis in nisi non purus venenatis tempus non eu nisi. Sed bibendum sapien vitae ultricies condimentum. Integer vel mattis lectus, consequat congue ex. Cras convallis odio volutpat nulla vehicula efficitur. Pellentesque eget justo neque. Morbi mattis vitae magna et suscipit. Etiam orci sapien, tincidunt non tellus eget, laoreet vestibulum massa. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris nec nisi enim. Donec risus odio, lobortis in odio malesuada, laoreet rutrum urna. Nunc sit amet euismod quam. +Fusce rhoncus ullamcorper nunc, ut pellentesque nisi dictum sed. Fusce sem mi, bibendum ut dictum at, porta in libero. Pellentesque placerat mollis sapien, sed eleifend lorem consequat in. Phasellus vel tempor ligula. Pellentesque tincidunt suscipit tortor vel blandit. Maecenas purus mi, mattis ac aliquam vel, rutrum eu nulla. Proin rhoncus nec sem a congue. Pellentesque sit amet sapien quam. Sed hendrerit neque id venenatis dignissim. +Vestibulum laoreet eu felis nec aliquam. Praesent gravida ornare odio nec porttitor. Donec ut tellus eros. Proin fringilla urna augue, vitae ornare leo varius non. Curabitur consectetur, purus in iaculis finibus, lectus lacus porttitor dolor, nec eleifend tellus massa eget tellus. Mauris sit amet convallis risus, a fermentum lorem. Suspendisse potenti. Curabitur vulputate finibus maximus. Interdum et malesuada fames ac ante ipsum primis in faucibus. In vel erat pellentesque, rhoncus magna vel, scelerisque mauris. +Nulla facilisi. Morbi mattis felis nec accumsan varius. Vestibulum in sodales arcu. Vivamus egestas, ante nec dapibus vestibulum, tellus ipsum rhoncus mi, at fermentum sapien justo nec turpis. Quisque rhoncus, urna sit amet imperdiet cursus, tortor lacus ultricies sapien, eu bibendum ligula enim id mi. Sed sem leo, pharetra in pulvinar sed, faucibus sed dui. Morbi tempus erat nec neque placerat tincidunt. +Quisque ut lorem sodales magna faucibus mattis. Aenean dui neque, gravida ut fringilla non, fermentum sit amet dolor. Mauris a sapien lacinia, elementum dolor in, sagittis metus. Donec viverra magna non lorem rutrum, at eleifend lacus volutpat. Nunc sit amet dolor tempor, blandit sapien a, consectetur magna. Suspendisse maximus nunc nec imperdiet aliquet. Nunc aliquam interdum purus quis pretium. Mauris molestie feugiat pellentesque. Nunc maximus, est sed consequat malesuada, risus turpis consequat velit, ac feugiat nunc magna vitae ligula. Vestibulum tincidunt massa ante, vitae pellentesque tortor rutrum sed. Aliquam vel est libero. Suspendisse et convallis orci. Cras sed lorem consectetur, blandit massa sit amet, semper neque. Vestibulum et mi euismod, imperdiet justo non, facilisis libero. +Sed at lacus ac tortor dictum tempus. Integer commodo purus lacus, ut pretium est tempor ac. Ut vulputate nulla magna, ac facilisis velit commodo in. Interdum et malesuada fames ac ante ipsum primis in faucibus. Donec pellentesque congue nibh nec eleifend. Ut ante turpis, sodales sed aliquet quis, tempus eu dui. Proin et eros non risus porttitor pharetra. +Mauris a urna id justo gravida ultrices. Mauris commodo sed ipsum a dictum. In posuere luctus scelerisque. Morbi sit amet gravida ipsum. Quisque vel dui sit amet ex lobortis eleifend non vel neque. 
Fusce sit amet imperdiet felis, eu tempor diam. Pellentesque sit amet turpis in libero tristique posuere. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris quis est suscipit, tristique odio elementum, molestie nibh. Maecenas ex dui, pulvinar quis pellentesque sed, imperdiet nec mauris. Pellentesque ultrices at mauris eget fringilla. Donec bibendum rhoncus felis, ut pretium nulla eleifend commodo. +Ut euismod erat accumsan tincidunt sagittis. Proin eget massa ex. Suspendisse at faucibus enim, vitae posuere mi. Cras nec ex finibus, porttitor purus quis, efficitur libero. Nulla sagittis ornare iaculis. Donec venenatis dui ut libero aliquam lobortis. Vestibulum imperdiet lorem urna, eget gravida orci sollicitudin ut. Quisque ultrices tortor at quam laoreet aliquet. Pellentesque tincidunt consequat pharetra. Cras a lacinia erat. Mauris sed neque lobortis ipsum facilisis hendrerit. +Cras at orci odio. Curabitur eros metus, consequat non placerat et, tincidunt at turpis. Morbi quis viverra metus. Vestibulum molestie, ex at suscipit finibus, ex magna pellentesque nisi, eu ullamcorper nisl sapien eu quam. Phasellus volutpat lacinia enim, nec fermentum augue tincidunt ut. Duis rutrum purus eu nulla elementum, a faucibus odio fringilla. Sed cursus risus neque, dictum luctus tortor tempus eu. +Mauris non arcu eu nunc faucibus tincidunt id quis dolor. Quisque ac fringilla libero. Sed non ligula ut nunc auctor consequat vitae eget metus. Ut suscipit leo quam, vitae ultrices urna feugiat eu. Vestibulum volutpat nisl quis nunc pretium, vel viverra orci fringilla. Proin erat nibh, laoreet nec nisi sit amet, volutpat efficitur nunc. Cras id tortor quis lectus imperdiet rutrum non id purus. Proin efficitur ligula non dapibus consectetur. Nam quis quam eget dui facilisis scelerisque. Praesent non bibendum risus. Etiam imperdiet nisi id consectetur porta. In pretium nulla ut leo ultricies rhoncus. +Curabitur non vehicula purus. Cras et justo risus. Duis et rutrum urna. Aliquam condimentum purus nec ante dignissim rhoncus. Vestibulum commodo pharetra eros, ac euismod orci rutrum vel. Integer sed cursus erat, euismod accumsan libero. Nullam ut odio sit amet nibh tempor congue. Pellentesque porttitor aliquam ipsum, sit amet facilisis quam fringilla ac. Aliquam scelerisque tempor nisl in tempor. Sed vestibulum, tellus sit amet mattis pellentesque, eros diam convallis felis, id pellentesque massa leo quis dolor. Integer dignissim orci lorem, vel porttitor felis blandit et. Nam ultrices enim sed elementum accumsan. Fusce rutrum, quam et feugiat maximus, lorem leo porttitor ex, a eleifend risus odio consectetur lacus. In hac habitasse platea dictumst. Aenean pharetra erat tellus, at tempus urna iaculis ut. Ut ac mi eu lorem volutpat egestas. +Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Praesent porttitor tempor ligula. Quisque mollis arcu in metus ornare pellentesque. Aenean ultrices mollis quam quis sodales. Maecenas a cursus elit, id gravida tortor. Donec vel purus magna. Aliquam elementum est sed convallis fermentum. Nam nec eros arcu. Pellentesque sed eros a lacus sagittis maximus. Integer et tellus id libero dapibus convallis. Maecenas viverra, purus facilisis porttitor tincidunt, tellus lacus elementum dui, sed porttitor sem justo a lorem. Curabitur ipsum odio, efficitur quis efficitur at, tempus aliquet nisi. Aliquam ultrices tortor in arcu vulputate, vel iaculis lorem facilisis. Cras eleifend laoreet feugiat. 
Integer placerat blandit sem, mattis elementum purus pellentesque quis. Etiam vel arcu ut mi commodo placerat non id tortor. +""" + +@pytest.mark.asyncio +async def test_supervisor_long_prompt_response( + pipeline_shard_meta: Callable[..., PipelineShardMetadata], + hosts: Callable[..., list[Host]], + chat_completion_task: Callable[[InstanceId, TaskId], Task], + lorem_ipsum: str, + logger: Logger, +): + """Test that asking for the capital of France returns 'Paris' in the response""" + logger_test_install(logger) + + model_meta = MODEL_CARDS['llama-3.2-1b'].metadata + model_shard_meta = PipelineShardMetadata( + model_meta=model_meta, + device_rank=0, + world_size=1, + n_layers=model_meta.n_layers, + start_layer=0, + end_layer=model_meta.n_layers, + ) + instance_id = InstanceId() + + print(f"{model_shard_meta=}") + + supervisor = await RunnerSupervisor.create( + model_shard_meta=model_shard_meta, + hosts=hosts(1, offset=10), + ) + + try: + full_response = "" + + task = chat_completion_task(instance_id, TaskId()) + task.task_params.messages[0].content = lorem_ipsum * 3 + + + async for chunk in supervisor.stream_response( + task=task + ): + if isinstance(chunk, TokenChunk): + full_response += chunk.text + + assert len(full_response) > 100 + + finally: + await supervisor.astop() + + +@pytest.mark.asyncio +async def test_supervisor_two_node_long_prompt_response( + pipeline_shard_meta: Callable[..., PipelineShardMetadata], + hosts: Callable[..., list[Host]], + chat_completion_task: Callable[[InstanceId, TaskId], Task], + lorem_ipsum: str, + logger: Logger, +): + """Test two-node long prompt inference""" + logger_test_install(logger) + instance_id = InstanceId() + + async def create_supervisor(shard_idx: int) -> RunnerSupervisor: + model_meta = MODEL_CARDS['llama-3.2-1b'].metadata + model_shard_meta = PipelineShardMetadata( + model_meta=model_meta, + device_rank=shard_idx, + world_size=2, + n_layers=model_meta.n_layers, + start_layer=0 if shard_idx == 0 else model_meta.n_layers // 2, + end_layer=model_meta.n_layers // 2 if shard_idx == 0 else model_meta.n_layers, + ) + supervisor = await RunnerSupervisor.create( + model_shard_meta=model_shard_meta, + hosts=hosts(2, offset=15), + ) + return supervisor + + create_supervisor_0 = asyncio.create_task(create_supervisor(0)) + create_supervisor_1 = asyncio.create_task(create_supervisor(1)) + supervisor_0, supervisor_1 = await asyncio.gather( + create_supervisor_0, create_supervisor_1 + ) + + await asyncio.sleep(0.1) + + try: + full_response_0 = "" + full_response_1 = "" + stop_reason_0: FinishReason | None = None + stop_reason_1: FinishReason | None = None + + task = chat_completion_task(instance_id, TaskId()) + task.task_params.messages[0].content = lorem_ipsum * 3 + + async def collect_response_0(): + nonlocal full_response_0, stop_reason_0 + async for chunk in supervisor_0.stream_response(task=task): + if isinstance(chunk, TokenChunk): + full_response_0 += chunk.text + if chunk.finish_reason: + stop_reason_0 = chunk.finish_reason + + async def collect_response_1(): + nonlocal full_response_1, stop_reason_1 + async for chunk in supervisor_1.stream_response(task=task): + if isinstance(chunk, TokenChunk): + full_response_1 += chunk.text + if chunk.finish_reason: + stop_reason_1 = chunk.finish_reason + + # Run both stream responses simultaneously + _ = await asyncio.gather(collect_response_0(), collect_response_1()) + + assert len(full_response_0) > 100 + assert len(full_response_1) > 100 + + finally: + await supervisor_0.astop() + await 
supervisor_1.astop() +
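utils_mlx.py now records the distributed rank/world size and adds broadcast_from_zero(), which keeps an integer identical across ranks by all-summing a one-element int32 array that only rank 0 fills in. A plain-Python sketch of why that reduction behaves like a broadcast; all_sum below is a stand-in for mx.distributed.all_sum, not the MLX call:

    # Sketch of the idea behind broadcast_from_zero() in utils_mlx.py: every
    # rank contributes a one-element value to an all-reduce sum, but only
    # rank 0 contributes a non-zero value, so the reduction degenerates into
    # a broadcast from rank 0.
    def all_sum(contributions: list[int]) -> int:
        return sum(contributions)


    value_on_rank_0 = 512                 # e.g. a freshly halved prefill_step_size
    world_size = 4
    contributions = [value_on_rank_0 if rank == 0 else 0 for rank in range(world_size)]

    # After the reduction, every rank holds rank 0's value.
    assert all_sum(contributions) == value_on_rank_0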
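make_kv_cache() builds one cache entry per model layer; layers replaced by auto_parallel's IdentityLayer (the ones owned by another pipeline stage) get a NullKVCache that holds zero tokens, so iterating [c.state for c in prompt_cache] stays uniform without allocating K/V for skipped layers. A minimal stand-in for that per-layer pairing; none of these classes are the MLX/mlx_lm ones:

    # Mirrors only the selection logic of make_kv_cache() in utils_mlx.py.
    class KVCache:
        def __init__(self) -> None:
            self.tokens: list[int] = []


    class NullKVCache(KVCache):
        """Zero-token placeholder for layers owned by another pipeline stage."""


    class IdentityLayer:
        """Stand-in for a layer this stage skips."""


    class AttentionLayer:
        """Stand-in for a layer this stage actually runs."""


    layers = [IdentityLayer(), AttentionLayer(), AttentionLayer()]
    prompt_cache = [
        NullKVCache() if isinstance(layer, IdentityLayer) else KVCache()
        for layer in layers
    ]
    assert len(prompt_cache) == len(layers)   # uniform: one entry per layer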
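The new exo/shared/types/worker/communication.py replaces newline-delimited JSON between supervisor and runner with length-prefixed frames (a 4-byte EXO1 magic plus a 4-byte big-endian payload length), so long prompts no longer hit readline/LimitOverrun limits; a legacy newline fallback is kept for backward compatibility. A stdlib-only round trip of the same framing, with the constants restated so the sketch is self-contained rather than importing the exo module:

    import io
    import json
    import struct
    from typing import BinaryIO

    # Same wire format as communication.py: magic, length header, JSON payload.
    MAGIC = b"EXO1"
    HDR_FMT = "!I"  # 4-byte big-endian length


    def write_frame_sync(stream: BinaryIO, obj: dict | bytes) -> None:
        payload = obj if isinstance(obj, bytes) else json.dumps(obj).encode("utf-8")
        stream.write(MAGIC + struct.pack(HDR_FMT, len(payload)) + payload)


    def read_frame_sync(stream: BinaryIO) -> dict:
        header = stream.read(8)
        if len(header) < 8:
            raise EOFError("No more data to read")
        if header[:4] != MAGIC:
            # Legacy newline-delimited JSON fallback, as in the patch.
            return json.loads((header + stream.readline()).strip().decode("utf-8"))
        (length,) = struct.unpack(HDR_FMT, header[4:])
        data = stream.read(length)
        if len(data) < length:
            raise EOFError(f"Expected {length} bytes, got {len(data)}")
        return json.loads(data.decode("utf-8"))


    buf = io.BytesIO()
    write_frame_sync(buf, {"type": "print_response", "text": "hello"})
    buf.seek(0)
    assert read_frame_sync(buf)["text"] == "hello"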
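generate_step() in runner.py now prefills the prompt in prefill_step_size chunks, evaluating the KV cache after each chunk, and halves the step size (synchronized across ranks via broadcast_from_zero) whenever a chunk takes more than about seven seconds to evaluate. The control flow reduced to plain Python; process_chunk and the timing stand in for the MLX forward pass and are assumptions of this sketch, not the runner implementation:

    import time
    from typing import Callable, Sequence


    def chunked_prefill(
        prompt: Sequence[int],
        process_chunk: Callable[[Sequence[int]], None],
        prefill_step_size: int = 2048,
        slow_chunk_s: float = 7.0,
    ) -> int:
        """Feed the prompt in chunks, shrinking the chunk size when a chunk
        is slow to evaluate (a proxy for memory pressure)."""
        processed = 0
        while len(prompt) - processed > prefill_step_size:
            chunk = prompt[processed:processed + prefill_step_size]
            start = time.time()
            process_chunk(chunk)      # model forward + KV-cache update in the real code
            processed += prefill_step_size
            if time.time() - start > slow_chunk_s:
                # In runner.py the new size is broadcast from rank 0 so all
                # pipeline ranks stay in lockstep.
                prefill_step_size = max(1, prefill_step_size // 2)
        # The remaining <= prefill_step_size tokens go to the first decode step.
        return processed


    chunked_prefill(list(range(10_000)), lambda chunk: None)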
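get_prefil_timeout() now scales with the actual prompt length reported by the runner's TokenizedResponse instead of returning a flat 30 seconds: prefill FLOPs are approximated as 2*P*S (P being shard weight bytes as a parameter proxy, S the prompt tokens), converted to seconds with a sustained-throughput constant, then padded. A worked example with illustrative numbers; the 8 GB shard, 4096-token prompt, 2 TFLOPS figure, and the prefill_timeout helper name are assumptions, not repository values:

    def prefill_timeout(
        weights_bytes: float,
        prompt_tokens: int,
        effective_tflops: float = 2.0,   # assumed sustained throughput
        safety_mult: float = 1.6,        # default from the patch
        base_pad_s: float = 5.0,         # default from the patch
    ) -> float:
        flops = 2.0 * weights_bytes * prompt_tokens      # 2*P*S parameter-dependent GEMMs
        compute_s = flops / (effective_tflops * 1e12)    # FLOPs -> seconds
        return base_pad_s + safety_mult * compute_s


    # 8 GB of shard weights, 4096 prompt tokens, 2 TFLOPS sustained:
    # 2 * 8e9 * 4096 / 2e12 ~= 32.8 s of compute, so roughly a 57 s timeout.
    print(round(prefill_timeout(8e9, 4096), 1))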
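until_event_with_timeout() in the integration tests now enforces a deadline while polling the event log, raising asyncio.TimeoutError instead of hanging forever. The shape of that loop, with a generic check() callable standing in for the event-log scan (an assumption of this sketch):

    import asyncio
    from typing import Callable


    async def poll_until(
        check: Callable[[], bool],
        timeout: float = 30.0,
        interval: float = 0.01,
    ) -> None:
        start = asyncio.get_running_loop().time()
        while not check():
            if asyncio.get_running_loop().time() - start > timeout:
                raise asyncio.TimeoutError(f"condition not met within {timeout}s")
            await asyncio.sleep(interval)


    asyncio.run(poll_until(lambda: True))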