diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml index 84212658..12df2a84 100644 --- a/.idea/inspectionProfiles/Project_Default.xml +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -4,9 +4,8 @@ diff --git a/.mlx_typings/mlx_lm/models/cache.pyi b/.mlx_typings/mlx_lm/models/cache.pyi index 177dde3a..37f96845 100644 --- a/.mlx_typings/mlx_lm/models/cache.pyi +++ b/.mlx_typings/mlx_lm/models/cache.pyi @@ -2,14 +2,24 @@ This type stub file was generated by pyright. """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Protocol, Literal, Self import mlx.nn as nn from mlx.core import array +import mlx.core as mx + +class Cache(Protocol): + keys: mx.array + values: mx.array + def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ... + @property + def state(self) -> tuple[mx.array, mx.array]: ... + @state.setter + def state(self, v) -> None: ... def make_prompt_cache( model: nn.Module, max_kv_size: Optional[int] = ... -) -> List[KVCache | Any]: +) -> List[Cache | Any]: """ Construct the model's cache for use in generation. @@ -24,7 +34,7 @@ def make_prompt_cache( """ def save_prompt_cache( - file_name: str, cache: List[Any], metadata: Dict[str, str] = ... + file_name: str, cache: List[Cache], metadata: Dict[str, str] = ... ) -> None: """ Save a pre-computed prompt cache to a file. @@ -50,12 +60,12 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array: the metadata if requested. """ -def can_trim_prompt_cache(cache: List[Any]) -> bool: +def can_trim_prompt_cache(cache: List[Cache]) -> bool: """ Check if model's cache can be trimmed. """ -def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: +def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]: """ Trim the model's cache by the given number of tokens. @@ -72,27 +82,22 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: def create_attention_mask( N: int, offset: int, return_array: bool, window_size: Optional[int] -): # -> array | Literal['causal'] | None: - ... +) -> array | Literal["causal"] | None: ... -class _BaseCache: +class _BaseCache(Cache): + keys: mx.array + values: mx.array @property - def state(self): # -> list[Any]: - ... + def state(self) -> tuple[mx.array, mx.array]: ... @state.setter - def state(self, v): # -> None: - ... + def state(self, v) -> None: ... @property - def meta_state(self): # -> Literal['']: - ... + def meta_state(self) -> Literal[""]: ... @meta_state.setter - def meta_state(self, v): # -> None: - ... - def is_trimmable(self): # -> Literal[False]: - ... + def meta_state(self, v) -> None: ... + def is_trimmable(self) -> Literal[False]: ... @classmethod - def from_state(cls, state, meta_state): # -> Self: - ... + def from_state(cls, state, meta_state) -> Self: ... class ConcatenateKVCache(_BaseCache): """ConcatenateKVCache the simplest KV cache implementation. diff --git a/TODO.md b/TODO.md index 85577411..fb5ef0d9 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,5 @@ -1. Currently EXO just doesn't start cleanly a lot of the time. I see two kinds of issues: - b. EXO starts but then after creating an instance that instance never loads (either gets stuck in Loading of Inactive). 2. Currently a lot of requests from the API are timing out, but we still process those requests internally. If an API request times out, we should cancel all corresponding tasks to that API request (why process a request with nobody listening). +3. 
Task cancellation. When API http request gets cancelled, it should cancel corresponding task. 4. I'd like to see profiled network latency / bandwidth. 5. I'd like to see how much bandwidth each link is using. 6. We should handle the case where one machine doesn't have the model downloaded and then other machines are waiting on it. In this case we get loads of timeout errors because the others are waiting for the one that needs to download the model. @@ -14,41 +13,13 @@ 16. Dynamically switch to higher priority connection when it becomes available. Probably bring back InstanceReplacedAtomically. 17. Faster model loads by streaming model from other devices in cluster. 18. Add support for specifying the type of network connection to use in a test. Depends on 15/16. -19. Fix mx.distributed.Group typing. 20. Add chat completion cancellations (e.g OpenWebUI has something for cancelling an ongoing request). -21. Make two separate things: tensor or pipeline, and ring or ibv. -22. When downloading for the first time, stuff times out and I think the model never ends up actually loading into memory, or something. 23. Do we need cache_limit? We went back and forth on that a lot because we thought it might be causing issues. One problem is it sets it relative to model size. So if you have multiple models loaded in it will take the most recent model size for the cache_limit. This is problematic if you launch DeepSeek -> Llama for example. -24. Task cancellation. When API http request gets cancelled, it should cancel corresponding task. +24. further openai/lmstudio api compatibility +25. Rethink retry logic Potential refactors: -1. Make ForwarderEvent typed 2. Topology can be simplified -3. Get rid of InstanceReplacedAtomically Random errors we've run into: - -1. exo.shared.types.worker.common.RunnerError: RuntimeError: [ibv] Couldn't connect (error: 60). Traceback: Traceback (most recent call last): - File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/worker/runner/runner.py", line 54, in main - model, tokenizer, sampler, group = await loop.run_in_executor( - ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - ...<8 lines>... - ) - ^ - File "/nix/store/s7ik6dazn4nd2jdg9l36qf5q0z18sjyk-python3-3.13.8/lib/python3.13/concurrent/futures/thread.py", line 59, in run - result = self.fn(*self.args, **self.kwargs) - File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/engines/mlx/utils_mlx.py", line 149, in initialize_mlx - group = mlx_distributed_init( - model_shard_meta.device_rank, - ...<4 lines>... - or (mlx_ibv_devices is not None and len(mlx_ibv_devices) > 1), - ) - File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/engines/mlx/utils_mlx.py", line 124, in mlx_distributed_init - group = mx.distributed.init( - backend="ring" if hosts is not None else "ibv", - strict=strict, - ) -RuntimeError: [ibv] Couldn't connect (error: 60) - -2. 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cd617aee..465ef15a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "filelock>=3.18.0", "aiosqlite>=0.21.0", "networkx>=3.5", - "openai>=1.99.9", "pathlib>=1.0.1", "protobuf>=6.32.0", "rich>=14.1.0", @@ -49,6 +48,7 @@ exo = "exo.main:main" dev = [ "pytest>=8.4.0", "pytest-asyncio>=1.0.0", + "pytest-env", "ruff>=0.11.13", ] @@ -131,4 +131,7 @@ asyncio_mode = "auto" markers = [ "slow: marks tests as slow (deselected by default)" ] +env = [ + "EXO_TESTS=1" +] addopts = "-m 'not slow'" diff --git a/src/exo/engines/mlx/constants.py b/src/exo/engines/mlx/constants.py deleted file mode 100644 index c73d62d3..00000000 --- a/src/exo/engines/mlx/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -# TODO: Do we want so many constants? - -KV_GROUP_SIZE = 32 -KV_BITS = None -ATTENTION_KV_BITS = 4 -MAX_TOKENS = 8192 -MAX_KV_SIZE = 3200 -KEEP_KV_SIZE = 1600 -QUANTIZE_MODEL_MODE = "affine" -CACHE_GROUP_SIZE = 64 -KV_CACHE_BITS = 8 -TEMPERATURE = 1.0 - -# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True -TRUST_REMOTE_CODE = True -# TODO: Do we really want this? -HIDE_THINKING = False diff --git a/src/exo/main.py b/src/exo/main.py index b21434af..382b957a 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -23,9 +23,7 @@ from exo.worker.download.impl_shard_downloader import exo_shard_downloader from exo.worker.main import Worker -# TODO: Entrypoint refactor # I marked this as a dataclass as I want trivial constructors. -# This is the collection of systems for our entire application. @dataclass class Node: router: Router diff --git a/src/exo/master/api.py b/src/exo/master/api.py index 69176792..f0ed302b 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -14,7 +14,6 @@ from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from loguru import logger -from exo.engines.mlx.constants import HIDE_THINKING from exo.shared.apply import apply from exo.shared.election import ElectionMessage from exo.shared.models.model_cards import MODEL_CARDS @@ -49,6 +48,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId from exo.utils.banner import print_startup_banner from exo.utils.channels import Receiver, Sender from exo.utils.event_buffer import OrderedBuffer +from exo.worker.engines.mlx.constants import HIDE_THINKING def chunk_to_response( diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index c96a8d35..4e512765 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -240,8 +240,6 @@ def _find_connection_ip( if ( connection.local_node_id == node_i.node_id and connection.send_back_node_id == node_j.node_id - # TODO: Check if we need this. 
- and connection.send_back_multiaddr is not None ): yield connection.send_back_multiaddr.ip_address diff --git a/src/exo/master/tests/conftest.py b/src/exo/master/tests/conftest.py index 9ebfa152..8441cef8 100644 --- a/src/exo/master/tests/conftest.py +++ b/src/exo/master/tests/conftest.py @@ -30,7 +30,7 @@ def create_node(): swap_available=1000, ), network_interfaces=[], - system=SystemPerformanceProfile(flops_fp16=1000), + system=SystemPerformanceProfile(), ), ) diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index 53b3fced..5aa26d48 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -28,9 +28,14 @@ from exo.shared.types.profiling import ( NodePerformanceProfile, SystemPerformanceProfile, ) -from exo.shared.types.tasks import ChatCompletionTask, TaskStatus -from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments -from exo.shared.types.worker.shards import PipelineShardMetadata +from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask +from exo.shared.types.tasks import TaskStatus +from exo.shared.types.worker.instances import ( + InstanceMeta, + MlxRingInstance, + ShardAssignments, +) +from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding from exo.utils.channels import channel @@ -91,7 +96,7 @@ async def test_master(): swap_available=Memory.from_bytes(0), ), network_interfaces=[], - system=SystemPerformanceProfile(flops_fp16=0), + system=SystemPerformanceProfile(), ), ) ), @@ -118,7 +123,8 @@ async def test_master(): n_layers=16, storage_size=Memory.from_bytes(678948), ), - strategy="auto", + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, ) ), ) @@ -160,9 +166,8 @@ async def test_master(): )[0] assert events[1].event == InstanceCreated( event_id=events[1].event.event_id, - instance=Instance( + instance=MlxRingInstance( instance_id=events[1].event.instance.instance_id, - instance_type=InstanceStatus.Active, shard_assignments=ShardAssignments( model_id=ModelId("llama-3.2-1b"), runner_to_shard={ @@ -186,22 +191,13 @@ async def test_master(): ), ) assert isinstance(events[2].event, TaskCreated) - assert events[2].event == TaskCreated( - event_id=events[2].event.event_id, - task_id=events[2].event.task_id, - task=ChatCompletionTask( - task_id=events[2].event.task_id, - command_id=events[2].event.task.command_id, - instance_id=events[2].event.task.instance_id, - task_status=TaskStatus.Pending, - task_params=ChatCompletionTaskParams( - model="llama-3.2-1b", - messages=[ - ChatCompletionMessage( - role="user", content="Hello, how are you?" 
- ) - ], - ), - ), + assert events[2].event.task.task_status == TaskStatus.Pending + assert isinstance(events[2].event.task, ChatCompletionTask) + assert events[2].event.task.task_params == ChatCompletionTaskParams( + model="llama-3.2-1b", + messages=[ + ChatCompletionMessage(role="user", content="Hello, how are you?") + ], ) + await master.shutdown() diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index c52b0b33..41cd8360 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -1,7 +1,6 @@ from typing import Callable import pytest -from loguru import logger from exo.master.placement import ( get_instance_placements_after_create, @@ -15,9 +14,15 @@ from exo.shared.types.memory import Memory from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile from exo.shared.types.topology import Connection, NodeInfo -from exo.shared.types.worker.common import InstanceId -from exo.shared.types.worker.instances import Instance, InstanceStatus +from exo.shared.types.worker.instances import ( + Instance, + InstanceId, + InstanceMeta, + MlxIbvInstance, + MlxRingInstance, +) from exo.shared.types.worker.runners import ShardAssignments +from exo.shared.types.worker.shards import Sharding @pytest.fixture @@ -27,9 +32,8 @@ def topology() -> Topology: @pytest.fixture def instance() -> Instance: - return Instance( + return MlxRingInstance( instance_id=InstanceId(), - instance_type=InstanceStatus.Active, shard_assignments=ShardAssignments( model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={} ), @@ -51,7 +55,8 @@ def create_instance_command(model_meta: ModelMetadata) -> CreateInstance: return CreateInstance( command_id=CommandId(), model_meta=model_meta, - strategy="auto", + sharding=Sharding.Pipeline, + instance_meta=InstanceMeta.MlxRing, ) @@ -78,11 +83,7 @@ def test_get_instance_placements_create_instance( available_memory ) # make it exactly fit across all nodes - create_instance_command = CreateInstance( - command_id=CommandId(), - model_meta=model_meta, - strategy="auto", - ) + cic = create_instance_command(model_meta) node_id_a = NodeId() node_id_b = NodeId() node_id_c = NodeId() @@ -94,9 +95,7 @@ def test_get_instance_placements_create_instance( topology.add_connection(create_connection(node_id_c, node_id_a)) # act - placements = get_instance_placements_after_create( - create_instance_command, topology, {} - ) + placements = get_instance_placements_after_create(cic, topology, {}) # assert assert len(placements) == 1 @@ -128,19 +127,15 @@ def test_get_instance_placements_one_node_exact_fit( topology = Topology() node_id = NodeId() topology.add_node(create_node(1000 * 1024, node_id)) - create_instance_command = CreateInstance( - command_id=CommandId(), - model_meta=ModelMetadata( + cic = create_instance_command( + ModelMetadata( model_id=ModelId("test-model"), storage_size=Memory.from_kb(1000), pretty_name="Test Model", n_layers=10, ), - strategy="auto", - ) - placements = get_instance_placements_after_create( - create_instance_command, topology, {} ) + placements = get_instance_placements_after_create(cic, topology, {}) assert len(placements) == 1 instance_id = list(placements.keys())[0] @@ -157,19 +152,15 @@ def test_get_instance_placements_one_node_fits_with_extra_memory( topology = Topology() node_id = NodeId() topology.add_node(create_node(1001 * 1024, node_id)) - create_instance_command = CreateInstance( - 
command_id=CommandId(), - model_meta=ModelMetadata( + cic = create_instance_command( + ModelMetadata( model_id=ModelId("test-model"), storage_size=Memory.from_kb(1000), pretty_name="Test Model", n_layers=10, ), - strategy="auto", - ) - placements = get_instance_placements_after_create( - create_instance_command, topology, {} ) + placements = get_instance_placements_after_create(cic, topology, {}) assert len(placements) == 1 instance_id = list(placements.keys())[0] @@ -186,19 +177,17 @@ def test_get_instance_placements_one_node_not_fit( topology = Topology() node_id = NodeId() topology.add_node(create_node(1000 * 1024, node_id)) - create_instance_command = CreateInstance( - command_id=CommandId(), + cic = create_instance_command( model_meta=ModelMetadata( model_id=ModelId("test-model"), storage_size=Memory.from_kb(1001), pretty_name="Test Model", n_layers=10, ), - strategy="auto", ) with pytest.raises(ValueError, match="No cycles found with sufficient memory"): - get_instance_placements_after_create(create_instance_command, topology, {}) + get_instance_placements_after_create(cic, topology, {}) def test_get_transition_events_no_change(instance: Instance): @@ -301,16 +290,12 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory( topology.add_connection(create_connection(node_id_e, node_id_y)) topology.add_connection(create_connection(node_id_f, node_id_z)) - create_instance_command = CreateInstance( - command_id=CommandId(), + cic = create_instance_command( model_meta=model_meta, - strategy="auto", ) # Act - placements = get_instance_placements_after_create( - create_instance_command, topology, {} - ) + placements = get_instance_placements_after_create(cic, topology, {}) # Assert the chosen cycle is A-B-C (contains at least one leaf node), even though # D-E-F has more total memory. 
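The hunks above and below switch CreateInstance from the stringly-typed strategy="auto" / "tensor_rdma" to two explicit enums, which also checks off TODO item 21 ("Make two separate things: tensor or pipeline, and ring or ibv"). Their definitions are not shown in this diff, so the following is only a hedged sketch of what the call sites imply: the member names appear in this diff, while the str-Enum base and the string values are assumptions mirroring ChunkType in chunks.py.

from enum import Enum


class Sharding(str, Enum):
    # How a model is split across nodes (exo.shared.types.worker.shards).
    Pipeline = "Pipeline"  # each node serves a consecutive range of layers
    Tensor = "Tensor"      # every layer's weights are split across all nodes


class InstanceMeta(str, Enum):
    # Which distributed backend an instance runs on (exo.shared.types.worker.instances).
    MlxRing = "MlxRing"  # ring over a host list (backend="ring" in mx.distributed.init)
    MlxIbv = "MlxIbv"    # ibverbs/RDMA with a device matrix and coordinator (backend="ibv")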
@@ -346,7 +331,6 @@ def test_tensor_rdma_backend_connectivity_matrix( ethernet_interface = NetworkInterfaceInfo( name="en0", ip_address="192.168.1.100", - type="ethernet", ) assert node_a.node_profile is not None @@ -377,13 +361,7 @@ def test_tensor_rdma_backend_connectivity_matrix( network_interfaces=[ NetworkInterfaceInfo( name="en3", - ip_address=conn_c_a.send_back_multiaddr.ip_address, - type="rdma", - ), - NetworkInterfaceInfo( - name="en4", - ip_address=conn_b_a.send_back_multiaddr.ip_address, - type="rdma", + ip_address=conn_a_b.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -395,15 +373,9 @@ def test_tensor_rdma_backend_connectivity_matrix( friendly_name="test", memory=node_b.node_profile.memory, network_interfaces=[ - NetworkInterfaceInfo( - name="en3", - ip_address=conn_c_b.send_back_multiaddr.ip_address, - type="rdma", - ), NetworkInterfaceInfo( name="en4", - ip_address=conn_a_b.send_back_multiaddr.ip_address, - type="rdma", + ip_address=conn_b_c.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -416,14 +388,8 @@ def test_tensor_rdma_backend_connectivity_matrix( memory=node_c.node_profile.memory, network_interfaces=[ NetworkInterfaceInfo( - name="en3", - ip_address=conn_a_c.send_back_multiaddr.ip_address, - type="rdma", - ), - NetworkInterfaceInfo( - name="en4", - ip_address=conn_b_c.send_back_multiaddr.ip_address, - type="rdma", + name="en5", + ip_address=conn_c_a.send_back_multiaddr.ip_address, ), ethernet_interface, ], @@ -436,29 +402,26 @@ def test_tensor_rdma_backend_connectivity_matrix( topology.add_connection(conn_a_b) topology.add_connection(conn_b_c) topology.add_connection(conn_c_a) - topology.add_connection(conn_b_a) - topology.add_connection(conn_c_b) - topology.add_connection(conn_a_c) - create_instance_command = CreateInstance( + cic = CreateInstance( + sharding=Sharding.Tensor, + instance_meta=InstanceMeta.MlxIbv, command_id=CommandId(), model_meta=model_meta, - strategy="tensor_rdma", ) - placements = get_instance_placements_after_create( - create_instance_command, topology, {} - ) + placements = get_instance_placements_after_create(cic, topology, {}) assert len(placements) == 1 instance_id = list(placements.keys())[0] instance = placements[instance_id] - assert instance.hosts is None - assert instance.mlx_ibv_devices is not None - assert instance.mlx_ibv_coordinator is not None + assert isinstance(instance, MlxIbvInstance) - matrix = instance.mlx_ibv_devices + assert instance.ibv_devices is not None + assert instance.ibv_coordinator is not None + + matrix = instance.ibv_devices assert len(matrix) == 3 for i in range(3): @@ -471,11 +434,9 @@ def test_tensor_rdma_backend_connectivity_matrix( idx_b = node_to_idx[node_id_b] idx_c = node_to_idx[node_id_c] - logger.info(matrix) + assert matrix[idx_a][idx_b] == "rdma_en3" + assert matrix[idx_b][idx_c] == "rdma_en4" + assert matrix[idx_c][idx_a] == "rdma_en5" - assert matrix[idx_a][idx_b] == "rdma_en4" - assert matrix[idx_b][idx_c] == "rdma_en3" - assert matrix[idx_c][idx_a] == "rdma_en3" - - assert ":" in instance.mlx_ibv_coordinator - assert not instance.mlx_ibv_coordinator.startswith("169.254") + assert ":" in instance.ibv_coordinator + assert not instance.ibv_coordinator.startswith("169.254") diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index 1da3e270..d5f42ccf 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -13,6 +13,7 @@ from exo.shared.types.common import Host, 
NodeId from exo.shared.types.memory import Memory from exo.shared.types.models import ModelId, ModelMetadata from exo.shared.types.topology import Connection, NodeInfo +from exo.shared.types.worker.shards import Sharding @pytest.fixture @@ -200,7 +201,9 @@ def test_get_shard_assignments( selected_cycle = cycles[0] # act - shard_assignments = get_shard_assignments(model_meta, selected_cycle, "pipeline") + shard_assignments = get_shard_assignments( + model_meta, selected_cycle, Sharding.Pipeline + ) # assert runner_id_a = shard_assignments.node_to_runner[node_a_id] diff --git a/src/exo/master/tests/test_topology.py b/src/exo/master/tests/test_topology.py index e794c445..d6afb339 100644 --- a/src/exo/master/tests/test_topology.py +++ b/src/exo/master/tests/test_topology.py @@ -32,7 +32,7 @@ def node_profile() -> NodePerformanceProfile: memory_profile = MemoryPerformanceProfile.from_bytes( ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 ) - system_profile = SystemPerformanceProfile(flops_fp16=1000) + system_profile = SystemPerformanceProfile() return NodePerformanceProfile( model_id="test", chip_id="test", @@ -99,7 +99,7 @@ def test_update_node_profile( ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000 ), network_interfaces=[], - system=SystemPerformanceProfile(flops_fp16=1000), + system=SystemPerformanceProfile(), ) # act diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 6ea031a7..5ef1c15a 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -10,6 +10,7 @@ from exo.shared.types.events import ( IndexedEvent, InstanceCreated, InstanceDeleted, + NodeCreated, NodeDownloadProgress, NodeMemoryMeasured, NodePerformanceMeasured, @@ -23,7 +24,6 @@ from exo.shared.types.events import ( TestEvent, TopologyEdgeCreated, TopologyEdgeDeleted, - TopologyNodeCreated, ) from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile from exo.shared.types.state import State @@ -41,14 +41,14 @@ def event_apply(event: Event, state: State) -> State: TestEvent() | ChunkGenerated() | TaskAcknowledged() ): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored return state - case NodeDownloadProgress(): - return apply_node_download_progress(event, state) case InstanceCreated(): return apply_instance_created(event, state) case InstanceDeleted(): return apply_instance_deleted(event, state) case NodePerformanceMeasured(): return apply_node_performance_measured(event, state) + case NodeDownloadProgress(): + return apply_node_download_progress(event, state) case NodeMemoryMeasured(): return apply_node_memory_measured(event, state) case RunnerDeleted(): @@ -63,7 +63,7 @@ def event_apply(event: Event, state: State) -> State: return apply_task_failed(event, state) case TaskStatusUpdated(): return apply_task_status_updated(event, state) - case TopologyNodeCreated(): + case NodeCreated(): return apply_topology_node_created(event, state) case TopologyEdgeCreated(): return apply_topology_edge_created(event, state) @@ -173,7 +173,6 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State: return state.model_copy(update={"runners": new_runners}) -# TODO: This whole function needs fixing def apply_node_performance_measured( event: NodePerformanceMeasured, state: State ) -> State: @@ -183,8 +182,8 @@ def apply_node_performance_measured( } state = state.model_copy(update={"node_profiles": new_profiles}) topology = copy.copy(state.topology) + # TODO: NodeCreated if not 
topology.contains_node(event.node_id): - # TODO: figure out why this is happening in the first place topology.add_node(NodeInfo(node_id=event.node_id)) topology.update_node_profile(event.node_id, event.node_profile) return state.model_copy(update={"topology": topology}) @@ -202,7 +201,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State memory=event.memory, network_interfaces=[], system=SystemPerformanceProfile( - flops_fp16=0.0, + # TODO: flops_fp16=0.0, gpu_usage=0.0, temp=0.0, sys_power=0.0, @@ -217,6 +216,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State } if not topology.contains_node(event.node_id): topology.add_node(NodeInfo(node_id=event.node_id)) + # TODO: NodeCreated topology.update_node_profile(event.node_id, created) return state.model_copy( update={"node_profiles": created_profiles, "topology": topology} @@ -227,6 +227,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State **state.node_profiles, event.node_id: updated, } + # TODO: NodeCreated if not topology.contains_node(event.node_id): topology.add_node(NodeInfo(node_id=event.node_id)) topology.update_node_profile(event.node_id, updated) @@ -235,7 +236,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State ) -def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State: +def apply_topology_node_created(event: NodeCreated, state: State) -> State: topology = copy.copy(state.topology) topology.add_node(NodeInfo(node_id=event.node_id)) return state.model_copy(update={"topology": topology}) diff --git a/src/exo/shared/openai_compat.py b/src/exo/shared/openai_compat.py deleted file mode 100644 index ed651356..00000000 --- a/src/exo/shared/openai_compat.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import TYPE_CHECKING, Literal, TypeAlias, get_type_hints - -if TYPE_CHECKING: - import openai.types as openai_types - import openai.types.chat as openai_chat - - types = openai_types - chat = openai_chat -else: - types = None - chat = None - -FinishReason: TypeAlias = Literal[ - "stop", "length", "tool_calls", "content_filter", "function_call" -] - -if TYPE_CHECKING: - assert ( - get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"] - == FinishReason - ), "Upstream changed Choice.finish_reason; update FinishReason alias." 
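openai_compat.py is being deleted here, and the openai dependency was dropped from pyproject.toml earlier in this diff, so nothing guards the vendored FinishReason Literal against upstream drift anymore. One replacement would be a dev-only pytest check; the sketch below is an assumption, not part of the diff: the test name is hypothetical, openai would need to come back as a dev-group dependency, and the comparison adds | None because current SDKs mark the streaming chunk's finish_reason as Optional (note the TYPE_CHECKING-guarded assert above could never run at runtime, since TYPE_CHECKING is False outside type checkers).

from typing import get_type_hints


def test_finish_reason_alias_matches_openai():
    import openai.types.chat as chat

    from exo.shared.types.api import FinishReason

    upstream = get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"]
    # Streaming chunks carry finish_reason=None until the final chunk, so the
    # upstream annotation is Optional[Literal[...]].
    assert upstream == (FinishReason | None), (
        "Upstream changed Choice.finish_reason; update the FinishReason alias."
    )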
- -__all__ = ["types", "chat", "FinishReason"] diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 3ec61289..56def4dc 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -3,12 +3,15 @@ from typing import Any, Literal from pydantic import BaseModel, Field -from exo.shared.openai_compat import FinishReason from exo.shared.types.common import CommandId from exo.shared.types.models import ModelMetadata from exo.shared.types.worker.instances import InstanceId, InstanceMeta from exo.shared.types.worker.shards import Sharding +FinishReason = Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call" +] + class ModelListModel(BaseModel): id: str diff --git a/src/exo/shared/types/chunks.py b/src/exo/shared/types/chunks.py index 990416c0..ac90d20c 100644 --- a/src/exo/shared/types/chunks.py +++ b/src/exo/shared/types/chunks.py @@ -1,9 +1,10 @@ from enum import Enum -from exo.shared.openai_compat import FinishReason -from exo.shared.types.models import ModelId from exo.utils.pydantic_ext import TaggedModel +from .api import FinishReason +from .models import ModelId + class ChunkType(str, Enum): Token = "Token" diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 1deca8ff..39c117f9 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -8,7 +8,6 @@ from exo.shared.types.worker.shards import Sharding from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel -# TODO: We need to have a distinction between create instance and spin up instance. class BaseCommand(TaggedModel): command_id: CommandId = Field(default_factory=CommandId) diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index ccc88185..3cc1c872 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -70,6 +70,11 @@ class RunnerDeleted(BaseEvent): runner_id: RunnerId +# TODO +class NodeCreated(BaseEvent): + node_id: NodeId + + class NodePerformanceMeasured(BaseEvent): node_id: NodeId node_profile: NodePerformanceProfile @@ -89,10 +94,6 @@ class ChunkGenerated(BaseEvent): chunk: GenerationChunk -class TopologyNodeCreated(BaseEvent): - node_id: NodeId - - class TopologyEdgeCreated(BaseEvent): edge: Connection @@ -116,7 +117,7 @@ Event = ( | NodeMemoryMeasured | NodeDownloadProgress | ChunkGenerated - | TopologyNodeCreated + | NodeCreated | TopologyEdgeCreated | TopologyEdgeDeleted ) diff --git a/src/exo/shared/types/profiling.py b/src/exo/shared/types/profiling.py index 3ebb6798..5ed6e0d4 100644 --- a/src/exo/shared/types/profiling.py +++ b/src/exo/shared/types/profiling.py @@ -1,5 +1,7 @@ from typing import Self +import psutil + from exo.shared.types.memory import Memory from exo.utils.pydantic_ext import CamelCaseModel @@ -21,9 +23,21 @@ class MemoryPerformanceProfile(CamelCaseModel): swap_available=Memory.from_bytes(swap_available), ) + @classmethod + def from_psutil(cls, *, override_memory: int | None) -> Self: + vm = psutil.virtual_memory() + sm = psutil.swap_memory() + + return cls.from_bytes( + ram_total=vm.total, + ram_available=vm.available if override_memory is None else override_memory, + swap_total=sm.total, + swap_available=sm.free, + ) + class SystemPerformanceProfile(CamelCaseModel): - flops_fp16: float + # TODO: flops_fp16: float gpu_usage: float = 0.0 temp: float = 0.0 @@ -36,7 +50,6 @@ class SystemPerformanceProfile(CamelCaseModel): class NetworkInterfaceInfo(CamelCaseModel): name: str ip_address: str - type: str class 
NodePerformanceProfile(CamelCaseModel): diff --git a/src/exo/shared/types/worker/commands_runner.py b/src/exo/shared/types/worker/commands_runner.py deleted file mode 100644 index 8878937f..00000000 --- a/src/exo/shared/types/worker/commands_runner.py +++ /dev/null @@ -1,45 +0,0 @@ -from exo.shared.openai_compat import FinishReason -from exo.utils.pydantic_ext import TaggedModel - - -class BaseRunnerResponse(TaggedModel): - pass - - -class InitializedResponse(BaseRunnerResponse): - time_taken: float - - -class TokenizedResponse(BaseRunnerResponse): - prompt_tokens: int - - -class GenerationResponse(BaseRunnerResponse): - text: str - token: int - # logprobs: list[float] | None = None # too big. we can change to be top-k - finish_reason: FinishReason | None = None - - -class PrintResponse(BaseRunnerResponse): - text: str - - -class FinishedResponse(BaseRunnerResponse): - pass - - -class ErrorResponse(BaseRunnerResponse): - error_type: str - error_message: str - traceback: str - - -RunnerResponse = ( - InitializedResponse - | TokenizedResponse - | GenerationResponse - | PrintResponse - | FinishedResponse - | ErrorResponse -) diff --git a/src/exo/shared/types/worker/common.py b/src/exo/shared/types/worker/common.py deleted file mode 100644 index 8b137891..00000000 --- a/src/exo/shared/types/worker/common.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/exo/shared/types/worker/ops.py b/src/exo/shared/types/worker/ops.py deleted file mode 100644 index 5dd98c9a..00000000 --- a/src/exo/shared/types/worker/ops.py +++ /dev/null @@ -1,34 +0,0 @@ -from exo.shared.types.tasks import Task -from exo.shared.types.worker.instances import BoundInstance, Instance -from exo.shared.types.worker.runners import RunnerId -from exo.utils.pydantic_ext import TaggedModel - - -class BaseRunnerOp(TaggedModel): - runner_id: RunnerId - - -class AssignRunnerOp(BaseRunnerOp): - instance: Instance - - def bound_instance(self) -> BoundInstance: - return BoundInstance(instance=self.instance, bound_runner_id=self.runner_id) - - -class UnassignRunnerOp(BaseRunnerOp): - pass - - -class RunnerUpOp(BaseRunnerOp): - pass - - -class RunnerDownOp(BaseRunnerOp): - pass - - -class ExecuteTaskOp(BaseRunnerOp): - task: Task - - -RunnerOp = AssignRunnerOp | ExecuteTaskOp | UnassignRunnerOp | RunnerUpOp | RunnerDownOp diff --git a/src/exo/shared/types/worker/runner_response.py b/src/exo/shared/types/worker/runner_response.py new file mode 100644 index 00000000..8c2d3754 --- /dev/null +++ b/src/exo/shared/types/worker/runner_response.py @@ -0,0 +1,21 @@ +from exo.shared.types.api import FinishReason +from exo.utils.pydantic_ext import TaggedModel + + +class BaseRunnerResponse(TaggedModel): + pass + + +class TokenizedResponse(BaseRunnerResponse): + prompt_tokens: int + + +class GenerationResponse(BaseRunnerResponse): + text: str + token: int + # logprobs: list[float] | None = None # too big. 
we can change to be top-k + finish_reason: FinishReason | None = None + + +class FinishedResponse(BaseRunnerResponse): + pass diff --git a/src/exo/utils/channels.py b/src/exo/utils/channels.py index c335fb02..72caa7ea 100644 --- a/src/exo/utils/channels.py +++ b/src/exo/utils/channels.py @@ -177,7 +177,7 @@ class MpReceiver[T]: try: item = self._state.buffer.get(block=False) - if item is MP_END_OF_STREAM: + if item == MP_END_OF_STREAM: self.close() raise EndOfStream assert not isinstance(item, _MpEndOfStream) @@ -193,7 +193,7 @@ class MpReceiver[T]: return self.receive_nowait() except WouldBlock: item = self._state.buffer.get() - if item is MP_END_OF_STREAM: + if item == MP_END_OF_STREAM: self.close() raise EndOfStream from None assert not isinstance(item, _MpEndOfStream) diff --git a/src/exo/worker/NOTES.md b/src/exo/worker/NOTES.md deleted file mode 100644 index 1170d0b9..00000000 --- a/src/exo/worker/NOTES.md +++ /dev/null @@ -1,2 +0,0 @@ -- Where should we check where the model is downloaded? -- Error handling. How do we handle the scenario where an operation keeps failing to execute diff --git a/src/exo/worker/download/conftest.py b/src/exo/worker/download/conftest.py deleted file mode 100644 index 4cf8b936..00000000 --- a/src/exo/worker/download/conftest.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from exo.shared.models.model_meta import get_model_meta -from exo.shared.types.models import ModelMetadata -from exo.shared.types.worker.shards import PipelineShardMetadata - - -@pytest.fixture -async def model_meta() -> ModelMetadata: - return await get_model_meta("mlx-community/Llama-3.2-1B-Instruct-4bit") - - -@pytest.fixture -def pipeline_shard_meta(model_meta: ModelMetadata): - def _pipeline_shard_meta( - num_nodes: int = 1, device_rank: int = 0 - ) -> PipelineShardMetadata: - total_layers = 16 - layers_per_node = total_layers // num_nodes - start_layer = device_rank * layers_per_node - end_layer = ( - start_layer + layers_per_node - if device_rank < num_nodes - 1 - else total_layers - ) - - return PipelineShardMetadata( - model_meta=model_meta, - device_rank=device_rank, - n_layers=total_layers, - start_layer=start_layer, - end_layer=end_layer, - world_size=num_nodes, - ) - - return _pipeline_shard_meta diff --git a/src/exo/worker/download/huggingface_utils.py b/src/exo/worker/download/huggingface_utils.py index fbf711e1..cde32a48 100644 --- a/src/exo/worker/download/huggingface_utils.py +++ b/src/exo/worker/download/huggingface_utils.py @@ -1,7 +1,7 @@ import os from fnmatch import fnmatch from pathlib import Path -from typing import Callable, Generator, Iterable, TypeVar +from typing import Callable, Generator, Iterable import aiofiles import aiofiles.os as aios @@ -9,10 +9,8 @@ from loguru import logger from exo.shared.types.worker.shards import ShardMetadata -T = TypeVar("T") - -def filter_repo_objects( +def filter_repo_objects[T]( items: Iterable[T], *, allow_patterns: list[str] | str | None = None, diff --git a/src/exo/worker/tests/test_handlers/__init__.py b/src/exo/worker/engines/__init__.py similarity index 100% rename from src/exo/worker/tests/test_handlers/__init__.py rename to src/exo/worker/engines/__init__.py diff --git a/src/exo/engines/mlx/__init__.py b/src/exo/worker/engines/mlx/__init__.py similarity index 99% rename from src/exo/engines/mlx/__init__.py rename to src/exo/worker/engines/mlx/__init__.py index 8c0c8fa3..d6f0b6b3 100644 --- a/src/exo/engines/mlx/__init__.py +++ b/src/exo/worker/engines/mlx/__init__.py @@ -1,9 +1,8 @@ from typing import Any -from 
mlx_lm.models.cache import KVCache - import mlx.core as mx import mlx.nn as nn +from mlx_lm.models.cache import KVCache # These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is. # For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function diff --git a/src/exo/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py similarity index 99% rename from src/exo/engines/mlx/auto_parallel.py rename to src/exo/worker/engines/mlx/auto_parallel.py index 4ff747b8..d6f419d5 100644 --- a/src/exo/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -3,8 +3,14 @@ from functools import partial from inspect import signature from typing import TYPE_CHECKING, Callable, Protocol, cast, override +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.distributed import ( + shard_inplace, + shard_linear, + sum_gradients, +) from mlx_lm.models.cache import ( - KVCache, _BaseCache, # pyright: ignore[reportPrivateUsage] ) from mlx_lm.models.deepseek_v3 import DeepseekV3MLP @@ -13,16 +19,9 @@ from mlx_lm.models.llama import Model as LlamaModel from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock -import mlx.core as mx -import mlx.nn as nn from exo.shared.types.worker.shards import ( PipelineShardMetadata, ) -from mlx.nn.layers.distributed import ( - shard_inplace, - shard_linear, - sum_gradients, -) class _LayerCallable(Protocol): @@ -94,7 +93,7 @@ class PipelineLastLayer(CustomMlxLayer): x, *args, **kwargs ).arguments.get("cache", None) - assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore + assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore output: mx.array = self.original_layer(x, *args, **kwargs) diff --git a/src/exo/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py similarity index 92% rename from src/exo/engines/mlx/cache.py rename to src/exo/worker/engines/mlx/cache.py index f4e7df8d..8a7f828b 100644 --- a/src/exo/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -1,14 +1,16 @@ +# type: ignore +# TODO: Fix this file, including types! from copy import deepcopy from typing import Callable +import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.models.cache import _BaseCache, trim_prompt_cache from mlx_lm.tokenizer_utils import TokenizerWrapper -import mlx.core as mx -from exo.engines.mlx import Model -from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE -from exo.engines.mlx.utils_mlx import make_kv_cache +from exo.worker.engines.mlx import Model +from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE +from exo.worker.engines.mlx.utils_mlx import make_kv_cache class KVPrefixCache: diff --git a/src/exo/worker/engines/mlx/constants.py b/src/exo/worker/engines/mlx/constants.py new file mode 100644 index 00000000..91c20de4 --- /dev/null +++ b/src/exo/worker/engines/mlx/constants.py @@ -0,0 +1,18 @@ +# TODO: Do we want so many constants? +# I think we want a lot of these as parameters? 
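As a hedged sketch of the "parameters" direction the TODO above floats (MlxEngineConfig is a hypothetical name, not in this diff; defaults are copied from the constants below), one option is a frozen dataclass constructed per runner and threaded through initialize_mlx / mlx_generate instead of importing module-level globals:

from dataclasses import dataclass


@dataclass(frozen=True)
class MlxEngineConfig:
    kv_group_size: int | None = 32
    kv_bits: int | None = None
    max_tokens: int = 8192
    max_kv_size: int | None = 3200
    keep_kv_size: int | None = 1600
    temperature: float = 1.0
    trust_remote_code: bool = True  # Kimi still requires this; see TODO below
    hide_thinking: bool = False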
+ +KV_GROUP_SIZE: int | None = 32 +KV_BITS: int | None = None +ATTENTION_KV_BITS: int | None = 4 +MAX_TOKENS: int = 8192 +MAX_KV_SIZE: int | None = 3200 +KEEP_KV_SIZE: int | None = 1600 +QUANTIZE_MODEL_MODE: str | None = "affine" +CACHE_GROUP_SIZE: int = 64 +KV_CACHE_BITS: int | None = 8 +TEMPERATURE: float = 1.0 + +# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True +TRUST_REMOTE_CODE: bool = True +# TODO: Do we really want this? +HIDE_THINKING: bool = False diff --git a/src/exo/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py similarity index 91% rename from src/exo/engines/mlx/utils_mlx.py rename to src/exo/worker/engines/mlx/utils_mlx.py index 8c48bd2e..c9f47449 100644 --- a/src/exo/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -1,6 +1,7 @@ import os import resource import time +from pathlib import Path from typing import Any, Callable, cast from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache @@ -8,29 +9,22 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model from mlx_lm.sample_utils import make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper -from exo.worker.runner.utils import get_weights_size +from exo.worker.engines.mlx.constants import ( + CACHE_GROUP_SIZE, + KV_CACHE_BITS, + TEMPERATURE, + TRUST_REMOTE_CODE, +) try: from mlx_lm.tokenizer_utils import load_tokenizer except ImportError: from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore +import mlx.core as mx +import mlx.nn as nn from mlx_lm.utils import load_model from pydantic import RootModel -import mlx.core as mx -import mlx.nn as nn -from exo.engines.mlx import Model -from exo.engines.mlx.auto_parallel import ( - pipeline_auto_parallel, - tensor_auto_parallel, -) -from exo.engines.mlx.constants import ( - CACHE_GROUP_SIZE, - KV_CACHE_BITS, - PATCH_SYSTEM_PROMPT, - TEMPERATURE, - TRUST_REMOTE_CODE, -) from exo.shared.types.api import ChatCompletionMessageText from exo.shared.types.common import Host from exo.shared.types.memory import Memory @@ -46,13 +40,31 @@ from exo.shared.types.worker.shards import ( TensorShardMetadata, ) from exo.worker.download.download_utils import build_model_path +from exo.worker.engines.mlx import Model +from exo.worker.engines.mlx.auto_parallel import ( + pipeline_auto_parallel, + tensor_auto_parallel, +) from exo.worker.runner.bootstrap import logger # Needed for 8 bit model resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096)) -mlx_rank: None | int = None -mlx_world_size: None | int = None + +# TODO: Test this +# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673 +def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: + return Memory.from_float_kb( + (model_shard_meta.end_layer - model_shard_meta.start_layer) + / model_shard_meta.n_layers + * model_shard_meta.model_meta.storage_size.in_kb + / ( + 1 + if isinstance(model_shard_meta, PipelineShardMetadata) + else model_shard_meta.world_size + ) + ) + def mx_barrier(group: mx.distributed.Group | None = None): mx.eval( @@ -65,10 +77,10 @@ def mx_barrier(group: mx.distributed.Group | None = None): def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None): - if mlx_rank is None: + if group is None: return value - if mlx_rank == 0: + if group.rank() == 0: a = mx.array([value], dtype=mx.int32) else: a = mx.array([0], dtype=mx.int32) @@ -154,10 +166,10 @@ def initialize_mlx( logger.info(f"Single device used for {bound_instance.instance}") model_path = 
build_model_path(bound_instance.bound_shard.model_meta.model_id) start_time = time.perf_counter() - model, config = load_model(model_path, strict=True) + model, _ = load_model(model_path, strict=True) end_time = time.perf_counter() logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s") - if isinstance(model.model, DeepseekV3Model): + if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore pass # model, config = quantize_model( # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE @@ -189,9 +201,9 @@ def shard_and_load( ) -> tuple[nn.Module, TokenizerWrapper]: model_path = build_model_path(shard_metadata.model_meta.model_id) - model, config = load_model(model_path, lazy=True, strict=False) + model, _ = load_model(model_path, lazy=True, strict=False) logger.debug(model) - if isinstance(model.model, DeepseekV3Model): + if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore pass # TODO: See if we should quantize the model. # def is_attention_layer(path: str) -> bool: @@ -199,7 +211,6 @@ def shard_and_load( # return "self_attn" in path and "layernorm" not in path - # def quant_predicate(path: str, module: nn.Module): # if not isinstance(module, nn.Linear): # return False @@ -237,7 +248,7 @@ def shard_and_load( return model, tokenizer -def get_tokenizer(model_path: str, shard_metadata: ShardMetadata): +def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata): tokenizer = cast( TokenizerWrapper, load_tokenizer( @@ -262,7 +273,7 @@ def apply_chat_template( messages = chat_task_data.messages formatted_messages: list[dict[str, Any]] = [] - for i, message in enumerate(messages): + for _, message in enumerate(messages): if isinstance(message.content, ChatCompletionMessageText): message.content = message.content.text if isinstance(message.content, list): @@ -276,7 +287,7 @@ def apply_chat_template( # Null values are not valid when applying templates in tokenizer formatted_messages.append( - {k: v for k, v in message.model_dump().items() if v is not None} + {k: v for k, v in message.model_dump().items() if v is not None} # type: ignore ) prompt: str = tokenizer.apply_chat_template( # type: ignore diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 830bd7ce..073b1dbb 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -226,9 +226,7 @@ class Worker: task_id=task.task_id, task_status=TaskStatus.Running ) ) - await self._handle_shard_download_process( - task, initial_progress - ) + self._handle_shard_download_process(task, initial_progress) case Shutdown(runner_id=runner_id): await self.runners.pop(runner_id).start_task(task) case task: @@ -313,7 +311,7 @@ class Worker: self._tg.start_soon(runner.run) return runner - async def _handle_shard_download_process( + def _handle_shard_download_process( self, task: DownloadModel, initial_progress: RepoDownloadProgress, diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index e44b1975..cc886b4b 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -17,6 +17,7 @@ from exo.shared.types.tasks import ( from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId from exo.shared.types.worker.runners import ( + RunnerFailed, RunnerId, RunnerLoaded, RunnerLoading, @@ -59,16 +60,21 @@ def _kill_runner( instances: Mapping[InstanceId, Instance], ) -> Shutdown 
| None: for runner in runners.values(): + runner_id = runner.bound_instance.bound_runner_id if (instance_id := runner.bound_instance.instance.instance_id) not in instances: - return Shutdown( - instance_id=instance_id, runner_id=runner.bound_instance.bound_runner_id - ) + return Shutdown(instance_id=instance_id, runner_id=runner_id) - """ --- Potential code to kill a runner if any runners in its instance have failed --- - global_runners_in_instance = runner.bound_instance.instance.shard_assignments.node_to_runner.values() - if any(isinstance(all_runners[runner_id], RunnerFailed) for runner_id in global_runners_in_instance if runner_id != runner.bound_instance.bound_runner_id): - Shutdown(instance_id=runner.bound_instance.instance.instance_id, runner_id=runner.bound_instance.bound_runner_id) - """ + for ( + global_runner_id + ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values(): + if runner_id == global_runner_id: + continue + + if isinstance(all_runners.get(global_runner_id, None), RunnerFailed): + return Shutdown( + instance_id=instance_id, + runner_id=runner_id, + ) def _create_runner( @@ -125,25 +131,36 @@ def _load_model( global_download_status: Mapping[NodeId, Sequence[DownloadProgress]], ) -> LoadModel | None: for runner in runners.values(): - if ( - all( + instance = runner.bound_instance.instance + shard_assignments = instance.shard_assignments + + all_downloads_complete_local = all( + any( isinstance(dp, DownloadCompleted) - if dp.shard_metadata - == runner.bound_instance.instance.shard_assignments.runner_to_shard[rid] - else True - for nid, rid in runner.bound_instance.instance.shard_assignments.node_to_runner.items() + and dp.shard_metadata == shard_assignments.runner_to_shard[rid] for dp in global_download_status[nid] ) - and isinstance(runner.status, RunnerWaitingForModel) - and all( - isinstance( - all_runners.get(global_runner_id, None), - (RunnerWaitingForModel, RunnerLoading, RunnerLoaded), - ) - for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard + for nid, rid in shard_assignments.node_to_runner.items() + ) + + runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel) + + all_runners_expecting_model = all( + isinstance( + all_runners.get(global_runner_id), + (RunnerWaitingForModel, RunnerLoading, RunnerLoaded), ) + for global_runner_id in shard_assignments.runner_to_shard + ) + + if ( + all_downloads_complete_local + and runner_is_waiting + and all_runners_expecting_model ): - return LoadModel(instance_id=runner.bound_instance.instance.instance_id) + return LoadModel(instance_id=instance.instance_id) + + return None def _ready_to_warmup( @@ -151,29 +168,37 @@ def _ready_to_warmup( all_runners: Mapping[RunnerId, RunnerStatus], ) -> StartWarmup | None: for runner in runners.values(): - if isinstance(runner.status, RunnerLoaded) and ( - ( - all( - isinstance( - all_runners.get(global_runner_id, None), - (RunnerLoaded, RunnerWarmingUp), - ) - for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard - ) - and runner.bound_instance.bound_shard.device_rank != 0 + instance = runner.bound_instance.instance + shard_assignments = instance.shard_assignments + shard = runner.bound_instance.bound_shard + device_rank = shard.device_rank + runner_id = runner.bound_instance.bound_runner_id + + is_runner_loaded = isinstance(runner.status, RunnerLoaded) + + # Rank != 0 + all_runners_loaded_or_warming_up = all( + isinstance( + all_runners.get(global_runner_id, None), + (RunnerLoaded, 
RunnerWarmingUp), ) - or ( - all( - isinstance( - all_runners.get(global_runner_id, None), (RunnerWarmingUp) - ) - for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard - if global_runner_id != runner.bound_instance.bound_runner_id - ) - and runner.bound_instance.bound_shard.device_rank == 0 - ) - ): - return StartWarmup(instance_id=runner.bound_instance.instance.instance_id) + for global_runner_id in shard_assignments.runner_to_shard + ) + + # Rank= 0 + all_other_runners_warming_up = all( + isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp) + for global_runner_id in shard_assignments.runner_to_shard + if global_runner_id != runner_id + ) + + nonzero_rank_ready = device_rank != 0 and all_runners_loaded_or_warming_up + zero_rank_ready = device_rank == 0 and all_other_runners_warming_up + + if is_runner_loaded and (nonzero_rank_ready or zero_rank_ready): + return StartWarmup(instance_id=instance.instance_id) + + return None def _pending_tasks( diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index e05b4789..22eab98a 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -1,8 +1,4 @@ -"""--- not doing this anymore -import faulthandler import os -import sys -""" import loguru @@ -11,45 +7,25 @@ from exo.shared.types.tasks import Task from exo.shared.types.worker.instances import BoundInstance from exo.utils.channels import MpReceiver, MpSender -""" -- not doing this anymore -def _redirect_stderr_to_file(path: str) -> None: - # Replace fd 2 (stderr) with a file descriptor pointing to `path` - fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644) - os.dup2(fd, 2) - os.close(fd) - # Rebind sys.stderr so Python's own writes go to the new fd as well (line-buffered) - sys.stderr = os.fdopen(2, "w", buffering=1, closefd=False) -""" +logger: "loguru.Logger" + + +if os.getenv("EXO_TESTS") == "1": + logger = loguru.logger def entrypoint( bound_instance: BoundInstance, event_sender: MpSender[Event], task_receiver: MpReceiver[Task], - # err_path: str, _logger: "loguru.Logger", ) -> None: - """ - Minimal entrypoint for the spawned child process. - - It redirects fd=2 (stderr) to a pipe provided by the parent, *then* imports - the heavy runner module so that any C/C++ or MLX logs/crashes land in that pipe. 
- """ - """ --- not doing this anymore - _redirect_stderr_to_file(err_path) - faulthandler.enable(file=sys.stderr, all_threads=True) - """ - import os - os.environ["MLX_METAL_FAST_SYNCH"] = "1" global logger logger = _logger - # Import the heavy runner only after stderr is redirected + # Import main after setting global logger - this lets us just import logger from this module from exo.worker.runner.runner import main main(bound_instance, event_sender, task_receiver) - - -logger: "loguru.Logger" diff --git a/src/exo/worker/runner/generate.py b/src/exo/worker/runner/generate.py index 134ac956..ae80797b 100644 --- a/src/exo/worker/runner/generate.py +++ b/src/exo/worker/runner/generate.py @@ -5,21 +5,19 @@ from mlx_lm import stream_generate from mlx_lm.models.cache import KVCache from mlx_lm.tokenizer_utils import TokenizerWrapper -from exo.engines.mlx import Model - # from exo.engines.mlx.cache import KVPrefixCache -from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS -from exo.engines.mlx.utils_mlx import ( +from exo.shared.types.api import ChatCompletionMessage, FinishReason +from exo.shared.types.tasks import ChatCompletionTaskParams +from exo.shared.types.worker.runner_response import ( + GenerationResponse, +) +from exo.worker.engines.mlx import Model +from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS +from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, make_kv_cache, mx_barrier, ) -from exo.shared.openai_compat import FinishReason -from exo.shared.types.api import ChatCompletionMessage -from exo.shared.types.tasks import ChatCompletionTaskParams -from exo.shared.types.worker.commands_runner import ( - GenerationResponse, -) from exo.worker.runner.bootstrap import logger generation_stream = mx.new_stream(mx.default_device()) diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index 87eb742d..81b43524 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -1,9 +1,6 @@ import time -from exo.engines.mlx.utils_mlx import ( - initialize_mlx, - mlx_force_oom, -) +from exo.shared.types.api import ChatCompletionMessageText from exo.shared.types.chunks import TokenChunk from exo.shared.types.events import ( ChunkGenerated, @@ -20,11 +17,10 @@ from exo.shared.types.tasks import ( Task, TaskStatus, ) -from exo.shared.types.worker.commands_runner import ( - GenerationResponse, - # TokenizedResponse, -) from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.runner_response import ( + GenerationResponse, +) from exo.shared.types.worker.runners import ( RunnerFailed, RunnerLoaded, @@ -37,6 +33,10 @@ from exo.shared.types.worker.runners import ( RunnerWarmingUp, ) from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender +from exo.worker.engines.mlx.utils_mlx import ( + initialize_mlx, + mlx_force_oom, +) from exo.worker.runner.bootstrap import logger from exo.worker.runner.generate import mlx_generate, warmup_inference @@ -142,27 +142,8 @@ def main( runner_id=runner_id, runner_status=current_status ) ) - # Ensure we have a chat-completion task subtype - # TODO: this is a hack, why are we only looking at the first message? should have a tokenizer - prompt = task_params.messages[0] - if ( - prompt.content is not None - and "EXO RUNNER MUST FAIL" in prompt.content - ): - logger.info("raising exception") - raise Exception( - "Artificial runner exception - for testing purposes only." 
- ) - if ( - prompt.content is not None - and "EXO RUNNER MUST OOM" in prompt.content - ): - mlx_force_oom() - if ( - prompt.content is not None - and "EXO RUNNER MUST TIMEOUT" in prompt.content - ): - time.sleep(100) + assert task_params.messages[0].content is not None + _check_for_debug_prompts(task_params.messages[0].content) # Generate responses using the actual MLX generation for response in mlx_generate( @@ -186,9 +167,9 @@ def main( ), ) ) - # case TokenizedResponse(): - # TODO: something here ig - # logger.info("Finished tokenizing?") + # case TokenizedResponse(): + # TODO: something here ig + logger.info("Finished tokenizing?") current_status = RunnerReady() logger.info("runner ready") @@ -233,3 +214,29 @@ def main( event_sender.join() task_receiver.join() logger.info("bye from the runner") + + +EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL" +EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM" +EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT" + + +def _check_for_debug_prompts( + prompt: str | ChatCompletionMessageText | list[ChatCompletionMessageText], +): + if isinstance(prompt, list): + if len(prompt) == 0: + logger.debug("Empty message prompt received in debug prompt") + return + prompt = prompt[0] + + if isinstance(prompt, ChatCompletionMessageText): + prompt = prompt.text + + if EXO_RUNNER_MUST_FAIL in prompt: + logger.info("raising exception") + raise Exception("Artificial runner exception - for testing purposes only.") + if EXO_RUNNER_MUST_OOM in prompt: + mlx_force_oom() + if EXO_RUNNER_MUST_TIMEOUT in prompt: + time.sleep(100) diff --git a/src/exo/worker/runner/utils.py b/src/exo/worker/runner/utils.py deleted file mode 100644 index 9cf22c95..00000000 --- a/src/exo/worker/runner/utils.py +++ /dev/null @@ -1,64 +0,0 @@ -import asyncio -import contextlib -import sys - -import psutil -from loguru import logger - -from exo.shared.types.memory import Memory -from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata - - -async def kill_process_tree(runner_process: asyncio.subprocess.Process) -> None: - """Kill the process and all its children forcefully.""" - if runner_process.returncode is not None: - return # Process already dead - - try: - # Get the main process - pid = runner_process.pid - - # Find all child processes - try: - parent = psutil.Process(pid) - children = parent.children(recursive=True) - - # Kill all children first (bottom-up) - for child in reversed(children): - with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied): - child.kill() # SIGKILL - - # Kill the parent - with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied): - parent.kill() # SIGKILL - - except psutil.NoSuchProcess: - # Process already gone, try subprocess kill anyway - runner_process.kill() - - # Wait for the subprocess to exit - try: - await asyncio.wait_for(runner_process.wait(), timeout=2.0) - except asyncio.TimeoutError: - logger.error(f"Process {pid} did not exit after kill signal") - - except Exception as e: - logger.error(f"Error killing process tree: {e}") - - -def get_runner_command() -> list[str]: - python = sys.executable - return [python, "-m", "exo.worker.runner.runner"] - - -def get_weights_size(model_shard_meta: ShardMetadata) -> Memory: - return Memory.from_float_kb( - (model_shard_meta.end_layer - model_shard_meta.start_layer) - / model_shard_meta.n_layers - * model_shard_meta.model_meta.storage_size.in_kb - / ( - 1 - if isinstance(model_shard_meta, PipelineShardMetadata) - else model_shard_meta.world_size - ) - ) diff --git 
diff --git a/src/exo/worker/tests/TODO.tests b/src/exo/worker/tests/TODO.tests
new file mode 100644
index 00000000..ab667fc3
--- /dev/null
+++ b/src/exo/worker/tests/TODO.tests
@@ -0,0 +1,57 @@
+Unit Tests
+1. Test worker plans as expected
+   - State transitions are correct
+   - Unexpected states throw
+
+2. Test runner
+   - Stays loaded
+   - Unloads under end condition
+   - Accepts tasks
+   - Returns ChunkGenerated events
+
+3. Test mlx engine
+   - Autoparallel on n of the same nodes returns tensors with 1/n size
+   - mx.barrier forces computation
+   - Distributed init returns expected configuration
+   - initialize_mlx sets wired limit
+   - shard_and_load returns expected model
+   - Quantization returns quantized layers
+
+4. Download
+   - hits the correct endpoint
+   - normalizes tags correctly
+   - updates download progress
+
+5. Serialization/Deserialization of tagged models
+
+
+Integration tests:
+1. Test model inference is "sensible" (per-configuration)
+   - Non-empty response
+   - Sensible inference speed
+   - Answers are non-gibberish for many seeds (What is the capital of France? -> "Paris" in answer.)
+   - Answer is the same for a particular seed
+
+2. Test that node count does not affect inference result (per-configuration)
+   - Llama on 1 node and on 2 nodes returns the same result, given temperature 0 and a fixed seed.
+   - Do for all configurations (Ring/Ibv, Pipeline/Tensor)
+
+3. Test supervisor catches exceptions gracefully
+   - Timeouts
+   - OOM
+   - MLX error
+
+4. Distributed init memory requirements are as expected
+
+5. MLX
+   - KVCache size is same length as prompt tokens
+   - Prefix cache (once implemented)
+
+6. Spin up creates a runner or goes to failed status
+
+
+Regression tests:
+1. Per-configuration baseline performance - no 20% drop in performance (device, node count, model, strategy, backend)
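Editor's note: integration item 2 above (node-count invariance) is worth pinning down early; one possible shape for it, where run_cluster_completion is a hypothetical helper that would be built on the worker fixtures - nothing below exists in the tree yet:

import pytest

PROMPT = "What is the capital of France?"


@pytest.mark.parametrize("strategy", ["pipeline", "tensor"])
async def test_node_count_invariance(strategy: str):
    # run_cluster_completion (hypothetical): spin up n workers, run one chat
    # completion with greedy sampling, return the concatenated response text.
    one = await run_cluster_completion(  # noqa: F821 - assumed helper
        nodes=1, strategy=strategy, prompt=PROMPT, temperature=0.0, seed=42
    )
    two = await run_cluster_completion(  # noqa: F821 - assumed helper
        nodes=2, strategy=strategy, prompt=PROMPT, temperature=0.0, seed=42
    )
    # Temperature 0 with a fixed seed should be reproducible across world sizes.
    assert one == two
    assert "Paris" in one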
diff --git a/src/exo/worker/tests/conftest.py b/src/exo/worker/tests/conftest.py
deleted file mode 100644
index 380a93d9..00000000
--- a/src/exo/worker/tests/conftest.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from typing import Callable
-
-import pytest
-
-from exo.shared.models.model_meta import get_model_meta
-from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
-from exo.shared.types.common import Host, NodeId
-from exo.shared.types.models import ModelId, ModelMetadata
-from exo.shared.types.tasks import (
-    ChatCompletionTask,
-    TaskId,
-    TaskStatus,
-)
-from exo.shared.types.worker.common import InstanceId
-from exo.shared.types.worker.instances import Instance, InstanceStatus
-from exo.shared.types.worker.runners import RunnerId, ShardAssignments
-from exo.shared.types.worker.shards import PipelineShardMetadata
-from exo.worker.main import Worker
-from exo.worker.tests.constants import (
-    COMMAND_1_ID,
-    INSTANCE_1_ID,
-    MODEL_A_ID,
-    NODE_A,
-    NODE_B,
-    RUNNER_1_ID,
-    TASK_1_ID,
-)
-
-from .worker_management import (
-    WorkerMailbox,
-    create_worker_and_mailbox,
-    create_worker_void_mailbox,
-    create_worker_with_old_mailbox,
-)
-
-
-@pytest.fixture
-def worker_void_mailbox() -> Worker:
-    return create_worker_void_mailbox(NODE_A)
-
-
-@pytest.fixture
-def worker_and_mailbox() -> tuple[Worker, WorkerMailbox]:
-    return create_worker_and_mailbox(NODE_A)
-
-
-@pytest.fixture
-def two_workers_with_shared_mailbox() -> tuple[Worker, Worker, WorkerMailbox]:
-    worker1, mailbox = create_worker_and_mailbox(NODE_A)
-    worker2 = create_worker_with_old_mailbox(NODE_B, mailbox)
-    return worker1, worker2, mailbox
-
-
-@pytest.fixture
-def user_message() -> str:
-    """Override this fixture in tests to customize the message"""
-    return "Hello, how are you?"
-
-
-@pytest.fixture
-async def model_meta() -> ModelMetadata:
-    return await get_model_meta("mlx-community/Llama-3.2-1B-Instruct-4bit")
-
-
-@pytest.fixture
-def hosts():
-    def _hosts(count: int, offset: int = 0) -> list[Host]:
-        return [
-            Host(
-                ip="127.0.0.1",
-                port=5000 + offset + i,
-            )
-            for i in range(count)
-        ]
-
-    return _hosts
-
-
-@pytest.fixture
-def pipeline_shard_meta(
-    model_meta: ModelMetadata,
-) -> Callable[[int, int], PipelineShardMetadata]:
-    def _pipeline_shard_meta(
-        num_nodes: int = 1, device_rank: int = 0
-    ) -> PipelineShardMetadata:
-        total_layers = model_meta.n_layers
-        layers_per_node = total_layers // num_nodes
-        start_layer = device_rank * layers_per_node
-        end_layer = (
-            start_layer + layers_per_node
-            if device_rank < num_nodes - 1
-            else total_layers
-        )
-
-        return PipelineShardMetadata(
-            model_meta=model_meta,
-            device_rank=device_rank,
-            n_layers=total_layers,
-            start_layer=start_layer,
-            end_layer=end_layer,
-            world_size=num_nodes,
-        )
-
-    return _pipeline_shard_meta
-
-
-@pytest.fixture
-def instance(
-    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
-    hosts: Callable[[int], list[Host]],
-):
-    def _instance(
-        instance_id: InstanceId | None = None,
-        node_id: NodeId | None = None,
-        runner_id: RunnerId | None = None,
-        model_id: ModelId | None = None,
-    ) -> Instance:
-        resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
-        resolved_node_id = node_id if node_id is not None else NODE_A
-        resolved_runner_id = runner_id if runner_id is not None else RUNNER_1_ID
-        resolved_model_id = model_id if model_id is not None else MODEL_A_ID
-
-        shard_assignments = ShardAssignments(
-            model_id=resolved_model_id,
-            runner_to_shard={resolved_runner_id: pipeline_shard_meta(1, 0)},
-            node_to_runner={resolved_node_id: resolved_runner_id},
-        )
-
-        return Instance(
-            instance_id=resolved_instance_id,
-            instance_type=InstanceStatus.Active,
-            shard_assignments=shard_assignments,
-            hosts=hosts(1),
-        )
-
-    return _instance
-
-
-@pytest.fixture
-def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
-    return ChatCompletionTaskParams(
-        model="gpt-4",
-        messages=[ChatCompletionMessage(role="user", content=user_message)],
-        stream=True,
-    )
-
-
-@pytest.fixture
-def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
-    def _chat_completion_task(
-        instance_id: InstanceId | None = None,
-        task_id: TaskId | None = None,
-        user_message: str = "Hello",
-    ) -> ChatCompletionTask:
-        resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
-        resolved_task_id = task_id if task_id is not None else TASK_1_ID
-        return ChatCompletionTask(
-            task_id=resolved_task_id,
-            command_id=COMMAND_1_ID,
-            instance_id=resolved_instance_id,
-            task_status=TaskStatus.Pending,
-            task_params=completion_create_params,
-        )
-
-    return _chat_completion_task
diff --git a/src/exo/worker/tests/constants.py b/src/exo/worker/tests/constants.py
index 85e16ed6..787f2ff7 100644
--- a/src/exo/worker/tests/constants.py
+++ b/src/exo/worker/tests/constants.py
@@ -3,7 +3,7 @@ from typing import Final
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import TaskId
-from exo.shared.types.worker.common import InstanceId, RunnerId
+from exo.shared.types.worker.instances import InstanceId, RunnerId
 
 MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
 
diff
--git a/src/exo/worker/tests/test_download.py b/src/exo/worker/tests/test_download.py deleted file mode 100644 index 3ce6b964..00000000 --- a/src/exo/worker/tests/test_download.py +++ /dev/null @@ -1,49 +0,0 @@ -import time -from typing import Callable - -import pytest - -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.download.impl_shard_downloader import exo_shard_downloader -from exo.worker.download.shard_downloader import ShardDownloader - - -@pytest.mark.slow -@pytest.mark.asyncio -async def test_shard_downloader( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], -): - shard_downloader: ShardDownloader = exo_shard_downloader() - shard_downloader.on_progress( - lambda shard, progress: print(f"Download progress: {progress}") - ) - - shard_metadata = pipeline_shard_meta(1, 0) - path = await shard_downloader.ensure_shard(shard_metadata) - assert path.exists() - - downloaded_model_path = path.parent / "mlx-community--Llama-3.2-1B-Instruct-4bit" - assert (downloaded_model_path / "config.json").exists() - assert (downloaded_model_path / "model.safetensors").exists() - assert (downloaded_model_path / "model.safetensors.index.json").exists() - assert (downloaded_model_path / "special_tokens_map.json").exists() - assert (downloaded_model_path / "tokenizer.json").exists() - assert (downloaded_model_path / "tokenizer_config.json").exists() - - expected_files_and_sizes = [ - ("config.json", 1121), - ("model.safetensors", 695283921), - ("model.safetensors.index.json", 26159), - ("special_tokens_map.json", 296), - ("tokenizer.json", 17209920), - ("tokenizer_config.json", 54558), - ] - for filename, expected_size in expected_files_and_sizes: - file_path = downloaded_model_path / filename - assert file_path.stat().st_size == expected_size, f"{filename} size mismatch" - - start_time = time.monotonic() - path_again = await shard_downloader.ensure_shard(shard_metadata) - duration = time.monotonic() - start_time - assert path_again == path - assert duration < 5, f"Second call to ensure_shard took too long: {duration:.2f}s" diff --git a/src/exo/worker/tests/test_handlers/conftest.py b/src/exo/worker/tests/test_handlers/conftest.py deleted file mode 100644 index 1cfd7a41..00000000 --- a/src/exo/worker/tests/test_handlers/conftest.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Callable - -import pytest - -from exo.shared.types.common import NodeId -from exo.shared.types.worker.common import InstanceId -from exo.shared.types.worker.instances import Instance -from exo.shared.types.worker.ops import ( - AssignRunnerOp, - RunnerUpOp, -) -from exo.shared.types.worker.runners import RunnerId -from exo.worker.main import Worker -from exo.worker.tests.constants import INSTANCE_1_ID, RUNNER_1_ID - - -@pytest.fixture -def user_message(): - return "What, according to Douglas Adams, is the meaning of life, the universe and everything?" - - -# TODO: instance_id and runner_id are selectable. 
-@pytest.fixture -async def worker_with_assigned_runner( - worker_void_mailbox: Worker, - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], -): - """Fixture that provides a worker with an already assigned runner.""" - worker = worker_void_mailbox - - instance_id = INSTANCE_1_ID - runner_id = RUNNER_1_ID - instance_obj: Instance = instance(instance_id, worker.node_id, runner_id) - - # Assign the runner - assign_op = AssignRunnerOp( - runner_id=runner_id, - shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id], - hosts=instance_obj.hosts, - instance_id=instance_obj.instance_id, - ) - - async for _ in worker.execute_op(assign_op): - pass - - return worker, instance_obj - - -@pytest.fixture -async def worker_with_running_runner( - worker_with_assigned_runner: tuple[Worker, Instance], -): - """Fixture that provides a worker with an already assigned runner.""" - worker, instance_obj = worker_with_assigned_runner - - runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) - async for _ in worker.execute_op(runner_up_op): - pass - - # Is the runner actually running? - supervisor = next(iter(worker.assigned_runners.values())).runner - assert supervisor is not None - assert supervisor.runner_process.is_alive() - - return worker, instance_obj diff --git a/src/exo/worker/tests/test_handlers/test_handlers_happy.py b/src/exo/worker/tests/test_handlers/test_handlers_happy.py deleted file mode 100644 index 89e1bc10..00000000 --- a/src/exo/worker/tests/test_handlers/test_handlers_happy.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import Callable - -import pytest - -from exo.shared.types.chunks import TokenChunk -from exo.shared.types.common import NodeId -from exo.shared.types.events import ( - ChunkGenerated, - RunnerDeleted, - RunnerStatusUpdated, - TaskStateUpdated, -) -from exo.shared.types.tasks import ChatCompletionTask, TaskStatus -from exo.shared.types.worker.common import RunnerId -from exo.shared.types.worker.instances import Instance, InstanceId -from exo.shared.types.worker.ops import ( - AssignRunnerOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerUpOp, - UnassignRunnerOp, -) -from exo.shared.types.worker.runners import ( - DownloadingRunnerStatus, - InactiveRunnerStatus, - LoadedRunnerStatus, - RunningRunnerStatus, - StartingRunnerStatus, -) -from exo.worker.main import Worker -from exo.worker.tests.constants import ( - RUNNER_1_ID, -) -from exo.worker.tests.test_handlers.utils import read_events_op - - -@pytest.mark.asyncio -async def test_assign_op( - worker_void_mailbox: Worker, - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], -): - worker = worker_void_mailbox - instance_obj: Instance = instance(InstanceId(), worker.node_id, RUNNER_1_ID) - - assign_op = AssignRunnerOp( - runner_id=RUNNER_1_ID, - shard_metadata=instance_obj.shard_assignments.runner_to_shard[RUNNER_1_ID], - hosts=instance_obj.hosts, - instance_id=instance_obj.instance_id, - ) - - events = await read_events_op(worker, assign_op) - - # We should have a status update saying 'starting'. 
- assert len(events) == 2 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, DownloadingRunnerStatus) - assert isinstance(events[1], RunnerStatusUpdated) - assert isinstance(events[1].runner_status, InactiveRunnerStatus) - - # And the runner should be assigned - assert RUNNER_1_ID in worker.assigned_runners - assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus) - - -@pytest.mark.asyncio -async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, Instance]): - worker, _ = worker_with_assigned_runner - - unassign_op = UnassignRunnerOp(runner_id=RUNNER_1_ID) - - events = await read_events_op(worker, unassign_op) - - # We should have no assigned runners and no events were emitted - assert len(worker.assigned_runners) == 0 - assert len(events) == 1 - assert isinstance(events[0], RunnerDeleted) - - -@pytest.mark.asyncio -async def test_runner_up_op( - worker_with_assigned_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_assigned_runner - - runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) - - events = await read_events_op(worker, runner_up_op) - - assert len(events) == 2 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, StartingRunnerStatus) - assert isinstance(events[1], RunnerStatusUpdated) - assert isinstance(events[1].runner_status, LoadedRunnerStatus) - - # Is the runner actually running? - supervisor = next(iter(worker.assigned_runners.values())).runner - assert supervisor is not None - assert supervisor.runner_process.is_alive() - - full_response = "" - - async for chunk in supervisor.stream_response(task=chat_completion_task()): - if isinstance(chunk, TokenChunk): - full_response += chunk.text - - assert "42" in full_response.lower(), ( - f"Expected '42' in response, but got: {full_response}" - ) - - runner = worker.assigned_runners[RUNNER_1_ID].runner - assert runner is not None - await runner.astop() # Neat cleanup. - - -@pytest.mark.asyncio -async def test_runner_down_op(worker_with_running_runner: tuple[Worker, Instance]): - worker, _ = worker_with_running_runner - - runner_down_op = RunnerDownOp(runner_id=RUNNER_1_ID) - events = await read_events_op(worker, runner_down_op) - - assert len(events) == 1 - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, InactiveRunnerStatus) - - -@pytest.mark.asyncio -async def test_execute_task_op( - worker_with_running_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_running_runner - - execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=chat_completion_task()) - - events = await read_events_op(worker, execute_task_op) - - assert len(events) > 20 - - print(f"{events=}") - - assert isinstance(events[0], RunnerStatusUpdated) - assert isinstance(events[0].runner_status, RunningRunnerStatus) - - assert isinstance(events[1], TaskStateUpdated) - assert events[1].task_status == TaskStatus.Running # It tried to start. - - assert isinstance(events[-2], TaskStateUpdated) - assert events[-2].task_status == TaskStatus.Complete # It tried to start. - - assert isinstance(events[-1], RunnerStatusUpdated) - assert isinstance( - events[-1].runner_status, LoadedRunnerStatus - ) # It should not have failed. 
- - gen_events: list[ChunkGenerated] = [ - x for x in events if isinstance(x, ChunkGenerated) - ] - text_chunks: list[TokenChunk] = [ - x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk) - ] - assert len(text_chunks) == len(events) - 4 - - output_text = "".join([x.text for x in text_chunks]) - assert "42" in output_text - - runner = worker.assigned_runners[RUNNER_1_ID].runner - assert runner is not None - await runner.astop() # Neat cleanup. diff --git a/src/exo/worker/tests/test_handlers/test_handlers_sad.py b/src/exo/worker/tests/test_handlers/test_handlers_sad.py deleted file mode 100644 index 97d2772c..00000000 --- a/src/exo/worker/tests/test_handlers/test_handlers_sad.py +++ /dev/null @@ -1,83 +0,0 @@ -## Tests for worker state handlers - -import asyncio -from typing import Callable - -import pytest - -from exo.shared.types.tasks import ChatCompletionTask -from exo.shared.types.worker.common import RunnerError -from exo.shared.types.worker.instances import Instance -from exo.shared.types.worker.ops import ( - ExecuteTaskOp, - RunnerUpOp, -) -from exo.worker.main import Worker -from exo.worker.tests.constants import RUNNER_1_ID -from exo.worker.tests.test_handlers.utils import read_events_op - - -@pytest.mark.asyncio -async def test_runner_up_fails( - worker_with_assigned_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_assigned_runner - worker.assigned_runners[RUNNER_1_ID].shard_metadata.immediate_exception = True - - runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) - - with pytest.raises(RunnerError): - await read_events_op(worker, runner_up_op) - - -@pytest.mark.asyncio -async def test_runner_up_timeouts( - worker_with_assigned_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_assigned_runner - worker.assigned_runners[RUNNER_1_ID].shard_metadata.should_timeout = 10 - - runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID) - - with pytest.raises(asyncio.TimeoutError): - await read_events_op(worker, runner_up_op) - - -@pytest.mark.asyncio -async def test_execute_task_fails( - worker_with_running_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_running_runner - - task = chat_completion_task() - messages = task.task_params.messages - messages[0].content = "Artificial prompt: EXO RUNNER MUST FAIL" - - execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=task) - - with pytest.raises(RunnerError): - await read_events_op(worker, execute_task_op) - - -@pytest.mark.asyncio -async def test_execute_task_timeouts( - worker_with_running_runner: tuple[Worker, Instance], - chat_completion_task: Callable[[], ChatCompletionTask], -): - worker, _ = worker_with_running_runner - - task = chat_completion_task() - messages = task.task_params.messages - messages[0].content = "Artificial prompt: EXO RUNNER MUST TIMEOUT" - - execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=task) - - with pytest.raises(asyncio.TimeoutError): - await read_events_op(worker, execute_task_op) - - -# TODO: Much more to do here! 
-# runner assigned download stuff diff --git a/src/exo/worker/tests/test_handlers/utils.py b/src/exo/worker/tests/test_handlers/utils.py deleted file mode 100644 index db5af33a..00000000 --- a/src/exo/worker/tests/test_handlers/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -## Tests for worker state handlers - - -from exo.shared.types.events import ( - Event, -) -from exo.shared.types.worker.ops import ( - RunnerOp, -) -from exo.worker.main import Worker - - -async def read_events_op(worker: Worker, op: RunnerOp) -> list[Event]: - events: list[Event] = [] - async for event in worker.execute_op(op): - events.append(event) - return events diff --git a/src/exo/worker/tests/test_integration/__init__.py b/src/exo/worker/tests/test_integration/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/exo/worker/tests/test_integration/test_inference.py b/src/exo/worker/tests/test_integration/test_inference.py deleted file mode 100644 index 7b9b07d0..00000000 --- a/src/exo/worker/tests/test_integration/test_inference.py +++ /dev/null @@ -1,262 +0,0 @@ -import asyncio -from typing import Callable - -import pytest -from anyio import create_task_group - -from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from exo.shared.types.common import CommandId, Host, NodeId -from exo.shared.types.events import ( - InstanceCreated, - InstanceDeleted, - TaskCreated, -) -from exo.shared.types.models import ModelId -from exo.shared.types.tasks import ( - ChatCompletionTask, - Task, - TaskId, - TaskStatus, -) -from exo.shared.types.worker.common import InstanceId, RunnerId -from exo.shared.types.worker.instances import ( - Instance, - InstanceStatus, - ShardAssignments, -) -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.main import Worker -from exo.worker.tests.constants import ( - INSTANCE_1_ID, - MASTER_NODE_ID, - NODE_A, - NODE_B, - RUNNER_1_ID, - RUNNER_2_ID, - TASK_1_ID, -) -from exo.worker.tests.worker_management import ( - WorkerMailbox, - read_streaming_response, -) - - -@pytest.fixture -def user_message(): - """Override this fixture in tests to customize the message""" - return "What's the capital of Japan?" - - -async def test_runner_inference( - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated( - instance=instance_value, - ), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - # TODO: This needs to get fixed - sometimes it misses the 'starting' event. - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert "tokyo" in response_string.lower() - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) - worker.shutdown() - # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. 
- - -async def test_2_runner_inference( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], -): - worker1, worker2, global_events = two_workers_with_shared_mailbox - async with create_task_group() as tg: - tg.start_soon(worker1.run) - tg.start_soon(worker2.run) - ## Instance - model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.Active, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert "tokyo" in response_string.lower() - - _ = global_events.collect() - await asyncio.sleep(1.0) - events = global_events.collect() - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - worker1.shutdown() - worker2.shutdown() - # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. - - -# TODO: Multi message parallel -async def test_2_runner_multi_message( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], -): - worker1, worker2, global_events = two_workers_with_shared_mailbox - async with create_task_group() as tg: - tg.start_soon(worker1.run) - tg.start_soon(worker2.run) - - ## Instance - model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit") - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.Active, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) - - # Task - we have three messages here, which is what the task is about - - completion_create_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="What is the capital of France?" - ), - ChatCompletionMessage( - role="assistant", content="The capital of France is Paris." - ), - ChatCompletionMessage( - role="user", - content="Ok great. 
Now write me a haiku about what you can do there.", - ), - ], - stream=True, - ) - - task = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=CommandId(), - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=completion_create_params, - ) - - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert any( - keyword in response_string.lower() - for keyword in ("kiss", "paris", "art", "love") - ) - - _ = global_events.collect() - await asyncio.sleep(1.0) - events = global_events.collect() - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - worker1.shutdown() - worker2.shutdown() - # TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly. diff --git a/src/exo/worker/tests/test_integration/test_inference_sad.py b/src/exo/worker/tests/test_integration/test_inference_sad.py deleted file mode 100644 index 595adb22..00000000 --- a/src/exo/worker/tests/test_integration/test_inference_sad.py +++ /dev/null @@ -1,311 +0,0 @@ -import asyncio -from collections.abc import AsyncGenerator -from types import CoroutineType -from typing import Any, Callable - -import pytest -from _pytest.monkeypatch import MonkeyPatch -from anyio import create_task_group - -from exo.shared.types.chunks import GenerationChunk, TokenChunk - -# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.types.common import NodeId -from exo.shared.types.events import ( - ChunkGenerated, - InstanceCreated, - InstanceDeleted, - RunnerStatusUpdated, - TaskCreated, - TaskFailed, - TaskStateUpdated, -) -from exo.shared.types.tasks import Task, TaskId, TaskStatus -from exo.shared.types.worker.common import InstanceId, RunnerId -from exo.shared.types.worker.instances import ( - Instance, - InstanceStatus, -) -from exo.shared.types.worker.runners import FailedRunnerStatus -from exo.worker.main import Worker -from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.tests.constants import ( - INSTANCE_1_ID, - MASTER_NODE_ID, - NODE_A, - RUNNER_1_ID, - TASK_1_ID, -) -from exo.worker.tests.worker_management import ( - WorkerMailbox, - until_event_with_timeout, -) - - -@pytest.fixture -def user_message(): - """Override this fixture in tests to customize the message""" - return "Who is the longest ruling monarch of England?" 
- - -async def test_stream_response_failed_always( - monkeypatch: MonkeyPatch, - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -) -> None: - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - - async def mock_stream_response( - self: RunnerSupervisor, - task: Task, - request_started_callback: Callable[..., CoroutineType[Any, Any, None]] - | None = None, - ) -> AsyncGenerator[GenerationChunk, None]: - raise RuntimeError("Simulated stream response failure") - - monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0) - - events = global_events.collect() - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 3 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.Failed - ] - ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) - worker.shutdown() - - -async def test_stream_response_failed_once( - monkeypatch: MonkeyPatch, - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - failed_already = False - original_stream_response = RunnerSupervisor.stream_response - - async def mock_stream_response( - self: RunnerSupervisor, - task: Task, - request_started_callback: Callable[..., CoroutineType[Any, Any, None]] - | None = None, - ) -> AsyncGenerator[GenerationChunk]: - nonlocal failed_already - if not failed_already: - failed_already = True - raise RuntimeError("Simulated stream response failure") - else: - async for event in original_stream_response( - self, task, request_started_callback - ): - yield event - return - - monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response) - - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout( - global_events, - ChunkGenerated, - 1, - condition=lambda x: isinstance(x.chunk, TokenChunk) - and x.chunk.finish_reason is not None, - timeout=30.0, - ) - - # TODO: The ideal with this test is if we had some tooling to scroll through the state, and say - # 'asser that there was a time that the error_type, error_message was not none and the failure 
count was nonzero' - - # as we reset the failures back to zero when we have a successful inference. - assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0 - assert worker.state.tasks[TASK_1_ID].error_type is None - assert worker.state.tasks[TASK_1_ID].error_message is None - - events = global_events.collect() - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 1 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.Failed - ] - ) - == 1 - ) - - response_string = "" - events = global_events.collect() - - seen_task_started, seen_task_finished = False, False - for wrapped_event in events: - event = wrapped_event.event - if isinstance(event, TaskStateUpdated): - if event.task_status == TaskStatus.Running: - seen_task_started = True - if event.task_status == TaskStatus.Complete: - seen_task_finished = True - - if isinstance(event, ChunkGenerated): - assert isinstance(event.chunk, TokenChunk) - response_string += event.chunk.text - - assert "queen" in response_string.lower() - assert seen_task_started - assert seen_task_finished - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) - worker.shutdown() - - -async def test_stream_response_timeout( - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - - task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT" - await global_events.append_events( - [ - InstanceCreated(instance=instance_value), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - await until_event_with_timeout( - global_events, TaskFailed, multiplicity=3, timeout=30.0 - ) - - events = global_events.collect() - print(events) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 3 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskStateUpdated) - and x.event.task_status == TaskStatus.Failed - ] - ) - == 3 - ) - assert ( - len( - [ - x - for x in events - if isinstance(x.event, TaskFailed) - and "timeouterror" in x.event.error_type.lower() - ] - ) - == 3 - ) - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance_value.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(0.3) - worker.shutdown() diff --git a/src/exo/worker/tests/test_integration/test_instantiation.py b/src/exo/worker/tests/test_integration/test_instantiation.py deleted file mode 100644 index 4d852123..00000000 --- a/src/exo/worker/tests/test_integration/test_instantiation.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Callable - -from anyio import create_task_group - -# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.types.common import NodeId - -# TaskStateUpdated and ChunkGenerated are used in 
test_worker_integration_utils.py -from exo.shared.types.events import ( - InstanceCreated, - InstanceDeleted, - RunnerStatusUpdated, -) -from exo.shared.types.worker.common import InstanceId, RunnerId -from exo.shared.types.worker.instances import ( - Instance, - InstanceStatus, -) -from exo.shared.types.worker.runners import ( - FailedRunnerStatus, -) -from exo.worker.main import Worker -from exo.worker.tests.constants import ( - INSTANCE_1_ID, - MASTER_NODE_ID, - NODE_A, - RUNNER_1_ID, -) -from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout - - -async def test_runner_spinup_timeout( - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - instance_value.shard_assignments.runner_to_shard[ - RUNNER_1_ID - ].should_timeout = 10 - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await until_event_with_timeout( - global_events, - RunnerStatusUpdated, - multiplicity=3, - condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), - ) - - # Ensure the correct events have been emitted - events = global_events.collect() - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - worker.shutdown() diff --git a/src/exo/worker/tests/test_integration/test_instantiation_sad.py b/src/exo/worker/tests/test_integration/test_instantiation_sad.py deleted file mode 100644 index e734ed49..00000000 --- a/src/exo/worker/tests/test_integration/test_instantiation_sad.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio -from typing import Callable - -from anyio import create_task_group - -# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.types.common import NodeId - -# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py -from exo.shared.types.events import ( - InstanceCreated, - InstanceDeleted, - RunnerStatusUpdated, -) -from exo.shared.types.worker.common import InstanceId, RunnerId -from exo.shared.types.worker.instances import ( - Instance, - InstanceStatus, -) -from exo.shared.types.worker.runners import ( - FailedRunnerStatus, -) -from exo.worker.main import Worker -from exo.worker.tests.constants import ( - INSTANCE_1_ID, - MASTER_NODE_ID, - NODE_A, - RUNNER_1_ID, -) -from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout - - -async def test_runner_spinup_exception( - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - instance_value.shard_assignments.runner_to_shard[ - RUNNER_1_ID - ].immediate_exception = True - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await asyncio.sleep(10.0) - - # Ensure the correct events have been emitted - events 
= global_events.collect() - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - worker.shutdown() - - -async def test_runner_spinup_timeout( - instance: Callable[[InstanceId, NodeId, RunnerId], Instance], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID) - instance_value.instance_type = InstanceStatus.Active - instance_value.shard_assignments.runner_to_shard[ - RUNNER_1_ID - ].should_timeout = 10 - - await global_events.append_events( - [InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID - ) - - await until_event_with_timeout( - global_events, - RunnerStatusUpdated, - multiplicity=3, - condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus), - ) - - # Ensure the correct events have been emitted - events = global_events.collect() - - assert ( - len( - [ - x - for x in events - if isinstance(x.event, RunnerStatusUpdated) - and isinstance(x.event.runner_status, FailedRunnerStatus) - ] - ) - == 3 - ) - assert any([isinstance(x.event, InstanceDeleted) for x in events]) - worker.shutdown() diff --git a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py b/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py deleted file mode 100644 index 60501d9c..00000000 --- a/src/exo/worker/tests/test_multimodel/test_inference_llama70B.py +++ /dev/null @@ -1,525 +0,0 @@ -import asyncio -import os -import time -from typing import Callable - -import pytest -from anyio import create_task_group - -from exo.shared.models.model_meta import get_model_meta -from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from exo.shared.types.common import Host -from exo.shared.types.events import ( - ChunkGenerated, - InstanceCreated, - InstanceDeleted, - RunnerStatusUpdated, - TaskCreated, -) -from exo.shared.types.models import ModelId, ModelMetadata -from exo.shared.types.tasks import ( - ChatCompletionTask, - Task, - TaskId, - TaskStatus, -) -from exo.shared.types.worker.common import InstanceId -from exo.shared.types.worker.instances import ( - Instance, - InstanceStatus, - ShardAssignments, -) -from exo.shared.types.worker.runners import LoadedRunnerStatus -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.main import Worker -from exo.worker.tests.constants import ( - COMMAND_1_ID, - COMMAND_2_ID, - INSTANCE_1_ID, - MASTER_NODE_ID, - NODE_A, - NODE_B, - RUNNER_1_ID, - RUNNER_2_ID, - TASK_1_ID, - TASK_2_ID, -) -from exo.worker.tests.worker_management import ( - WorkerMailbox, - read_streaming_response, - until_event_with_timeout, -) - -MODEL_ID = "mlx-community/Llama-3.3-70B-Instruct-4bit" -SKIP = True - - -@pytest.fixture -async def model_meta() -> ModelMetadata: - return await get_model_meta(MODEL_ID) - - -def _get_model_size_gb(path: str) -> float: - """Calculate total size of directory recursively in GB.""" - total_size = 0 - for dirpath, _, filenames in os.walk(path): - for filename in filenames: - filepath = os.path.join(dirpath, filename) - if os.path.isfile(filepath): - total_size += os.path.getsize(filepath) - return total_size / (1024**3) # Convert bytes to GB - - -skip = SKIP or not ( - os.path.exists( - 
os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/") - ) - and _get_model_size_gb( - os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/") - ) - > 30 -) - - -@pytest.mark.skipif( - skip, - reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", -) -async def test_ttft( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - worker_and_mailbox: tuple[Worker, WorkerMailbox], -): - from loguru import logger - - worker, global_events = worker_and_mailbox - async with create_task_group() as tg: - tg.start_soon(worker.run) - ## Instance - model_id = ModelId(MODEL_ID) - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={RUNNER_1_ID: pipeline_shard_meta(1, 0)}, - node_to_runner={NODE_A: RUNNER_1_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.Active, - shard_assignments=shard_assignments, - hosts=hosts(1), - ) - - # Create instance first - await global_events.append_events( - [InstanceCreated(instance=instance)], origin=MASTER_NODE_ID - ) - - await until_event_with_timeout( - global_events, - event_type=RunnerStatusUpdated, - condition=lambda x: isinstance(x.runner_status, LoadedRunnerStatus), - ) - logger.info("model loaded.") - - # First inference - task1_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="Please write a haiku about a flower." - ) - ], - stream=True, - max_tokens=100, - ) - task1 = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=task1_params, - ) - - print("Starting first inference...") - # Clean out the current global events - _ = global_events.collect() - - task_created_time_1 = time.time() - await global_events.append_events( - [TaskCreated(task_id=task1.task_id, task=task1)], origin=MASTER_NODE_ID - ) - - # Wait for first chunk to measure time to first token - first_chunk_seen_1 = False - time_to_first_token_1: None | float = None - while not first_chunk_seen_1: - event = (await global_events.receive()).event - if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"): - first_chunk_time_1 = time.time() - time_to_first_token_1 = first_chunk_time_1 - task_created_time_1 - first_chunk_seen_1 = True - break - - ( - _, - seen_task_finished_1, - response_string_1, - token_count_1, - ) = await read_streaming_response(global_events) - total_time_1 = time.time() - task_created_time_1 - - assert seen_task_finished_1 - - # Wait for first task to complete - await asyncio.sleep(5.0) - - # Second inference - task2_params = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content="Write me a haiku about a robot." 
- ) - ], - stream=True, - max_tokens=150, - ) - task2 = ChatCompletionTask( - task_id=TASK_2_ID, - command_id=COMMAND_2_ID, - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=task2_params, - ) - - print("Starting second inference...") - # Clean out the current global events - # Record the current event index before creating the second task - _ = global_events.collect() - - task_created_time_2 = time.time() - await global_events.append_events( - [TaskCreated(task_id=task2.task_id, task=task2)], origin=MASTER_NODE_ID - ) - - # Wait for first chunk of second task to measure time to first token - first_chunk_seen_2 = False - time_to_first_token_2: float | None = None - while not first_chunk_seen_2: - event = (await global_events.receive()).event - if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"): - first_chunk_time_2 = time.time() - time_to_first_token_2 = first_chunk_time_2 - task_created_time_2 - first_chunk_seen_2 = True - break - - ( - _, - seen_task_finished_2, - response_string_2, - token_count_2, - ) = await read_streaming_response(global_events, filter_task=TASK_2_ID) - total_time_2 = time.time() - task_created_time_2 - - assert seen_task_finished_2 - assert time_to_first_token_1 - assert time_to_first_token_2 - - # Calculate TPS metrics - # Prompt is approximately 45 tokens according to user - prompt_tokens = 45 - - # Prefill TPS = prompt tokens / time to first token - prefill_tps_1 = ( - prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0 - ) - prefill_tps_2 = ( - prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0 - ) - - # Generation TPS = generated tokens / generation time - # Generation time = total time - time to first token - generation_time_1 = total_time_1 - time_to_first_token_1 - generation_time_2 = total_time_2 - time_to_first_token_2 - generation_tps_1 = ( - token_count_1 / generation_time_1 if generation_time_1 > 0 else 0 - ) - generation_tps_2 = ( - token_count_2 / generation_time_2 if generation_time_2 > 0 else 0 - ) - - # Display time to first token profiling results - print("\n=== Time to First Token Profiling ===") - print(f"First inference ('{task1.task_params.messages[0].content}'):") - print(f" Time to first token: {time_to_first_token_1:.3f}s") - print(f" Total completion time: {total_time_1:.3f}s") - print(f" Tokens generated: {token_count_1}") - print(f" Response length: {len(response_string_1)} chars") - print( - f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)" - ) - print( - f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)" - ) - - print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):") - print(f" Time to first token: {time_to_first_token_2:.3f}s") - print(f" Total completion time: {total_time_2:.3f}s") - print(f" Tokens generated: {token_count_2}") - print(f" Response length: {len(response_string_2)} chars") - print( - f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)" - ) - print( - f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)" - ) - - print("\nComparison:") - print( - f" Second inference time to first token: {time_to_first_token_2 / time_to_first_token_1:.2f}x the first" - ) - print( - f" Second inference prefill TPS: {prefill_tps_2 / prefill_tps_1:.2f}x the first" - ) - print( - f" Second inference generation TPS: 
{generation_tps_2 / generation_tps_1:.2f}x the first" - ) - - # Basic assertions to ensure responses make sense - assert len(response_string_1) > 0 - assert len(response_string_2) > 0 - assert time_to_first_token_1 and time_to_first_token_1 > 0 - assert time_to_first_token_2 and time_to_first_token_2 > 0 - - # Cleanup - _ = global_events.collect() - await asyncio.sleep(1.0) - events = global_events.collect() - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - worker.shutdown() - - -@pytest.mark.skipif( - skip, - reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", -) -async def test_2_runner_inference( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], -): - worker1, worker2, global_events = two_workers_with_shared_mailbox - - async with create_task_group() as tg: - tg.start_soon(worker1.run) - tg.start_soon(worker2.run) - ## Instance - model_id = ModelId(MODEL_ID) - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.Active, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[ - 0 - ].content = "Can you explain to me how a bubble sort works, speaking as if you are a fairy." 
- task.task_params.max_tokens = 1000 - - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task.task_id, task=task), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started, - seen_task_finished, - response_string, - _, - ) = await read_streaming_response(global_events) - - assert seen_task_started - assert seen_task_finished - assert "swap" in response_string.lower() - - _ = global_events.collect() - await asyncio.sleep(1.0) - events = global_events.collect() - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - - worker1.shutdown() - worker2.shutdown() - - -@pytest.mark.skipif( - skip, - reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded", -) -async def test_parallel_inference( - pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata], - hosts: Callable[[int], list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox], -): - worker1, worker2, global_events = two_workers_with_shared_mailbox - - async with create_task_group() as tg: - tg.start_soon(worker1.run) - tg.start_soon(worker2.run) - - ## Instance - model_id = ModelId(MODEL_ID) - - shard_assignments = ShardAssignments( - model_id=model_id, - runner_to_shard={ - RUNNER_1_ID: pipeline_shard_meta(2, 0), - RUNNER_2_ID: pipeline_shard_meta(2, 1), - }, - node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, - ) - - instance = Instance( - instance_id=INSTANCE_1_ID, - instance_type=InstanceStatus.Active, - shard_assignments=shard_assignments, - hosts=hosts(2), - ) - - completion_create_params_1 = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content='Tell me a haiku that uses the word "pond".' - ) - ], - stream=True, - max_tokens=1000, - ) - task1 = ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=completion_create_params_1, - ) - - completion_create_params_2 = ChatCompletionTaskParams( - model="gpt-4", - messages=[ - ChatCompletionMessage( - role="user", content='Tell me a haiku that uses the word "tree".' 
- ) - ], - stream=True, - max_tokens=1000, - ) - task2 = ChatCompletionTask( - task_id=TASK_2_ID, - command_id=COMMAND_2_ID, - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=completion_create_params_2, - ) - - await global_events.append_events( - [ - InstanceCreated(instance=instance), - TaskCreated(task_id=task1.task_id, task=task1), - TaskCreated(task_id=task2.task_id, task=task2), - ], - origin=MASTER_NODE_ID, - ) - - ( - seen_task_started_1, - seen_task_finished_1, - response_string_1, - _, - ) = await read_streaming_response(global_events) - - incomplete_task = ( - TASK_2_ID - if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.Complete - else TASK_2_ID - ) - ( - seen_task_started_2, - seen_task_finished_2, - response_string_2, - _, - ) = await read_streaming_response(global_events, filter_task=incomplete_task) - - assert seen_task_started_1 - assert seen_task_finished_1 - assert seen_task_started_2 - assert seen_task_finished_2 - - print(response_string_1) - print(response_string_2) - - assert ("pond" in response_string_1.lower()) ^ ( - "pond" in response_string_2.lower() - ), "'pond' must appear in exactly one response" - assert ("tree" in response_string_1.lower()) ^ ( - "tree" in response_string_2.lower() - ), "'tree' must appear in exactly one response" - - _ = global_events.collect() - await asyncio.sleep(1.0) - events = global_events.collect() - assert len(events) == 0 - - await global_events.append_events( - [ - InstanceDeleted( - instance_id=instance.instance_id, - ), - ], - origin=MASTER_NODE_ID, - ) - - await asyncio.sleep(2.0) - - worker1.shutdown() - worker2.shutdown() diff --git a/src/exo/worker/tests/test_plan/test_worker_plan.py b/src/exo/worker/tests/test_plan/test_worker_plan.py deleted file mode 100644 index c555edd4..00000000 --- a/src/exo/worker/tests/test_plan/test_worker_plan.py +++ /dev/null @@ -1,550 +0,0 @@ -import pytest -from exo.worker.common import AssignedRunner - -from exo.shared.types.api import ChatCompletionMessage -from exo.shared.types.state import State -from exo.shared.types.tasks import ( - ChatCompletionTask, - ChatCompletionTaskParams, - TaskStatus, -) -from exo.shared.types.worker.common import WorkerStatus -from exo.shared.types.worker.downloads import ( - DownloadPending, -) -from exo.shared.types.worker.instances import InstanceStatus -from exo.shared.types.worker.ops import ( - AssignRunnerOp, - ExecuteTaskOp, - RunnerDownOp, - RunnerUpOp, - UnassignRunnerOp, -) -from exo.shared.types.worker.runners import ( - DownloadingRunnerStatus, - FailedRunnerStatus, - InactiveRunnerStatus, - LoadedRunnerStatus, - RunningRunnerStatus, -) -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.main import Worker -from exo.worker.plan import plan -from exo.worker.tests.constants import ( - COMMAND_1_ID, - INSTANCE_1_ID, - MODEL_A_ID, - NODE_A, - NODE_B, - RUNNER_1_ID, - RUNNER_2_ID, - TASK_1_ID, -) -from exo.worker.tests.test_plan.test_worker_plan_utils import ( - InProcessRunner, - PlanTestCase, - make_downloading_status, - make_model_meta, - make_state, - make_test_case, -) - -""" -The idea with these tests is to define declaratively the input and expected output of the worker.plan function. - -We initialize a Worker with InProcessRunners. We then construct a State which gets passed to Worker.plan. -We then check what operation is returned by Worker.plan. - -Note that the 'self' node will always be NODE_A. This leads to the swapped-around cases when checking failure cases etc. 
-""" - - -def _get_test_cases() -> list[PlanTestCase]: - # The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation. - model_a_meta = make_model_meta(MODEL_A_ID) - return [ - PlanTestCase( - description="no runners -> no-op", - in_process_runners=[], - state=State( - node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={} - ), - expected_op=None, - ), - # Both 'assigned' and 'downloading' should be blocking ops - so if we are in either of these we should unassign to retry. - # This needs to change when we move to an async worker - make_test_case( - description="runner state assigned, runner is assigned and downloading -> unassign", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": make_downloading_status(NODE_A), - "downloaded": False, - } - ], - instance_status=InstanceStatus.Inactive, - expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="ready runner, model present -> no-op", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": InactiveRunnerStatus(), - "downloaded": True, - } - ], - instance_status=InstanceStatus.Inactive, - expected_op=None, - ), - PlanTestCase( - description="runner assigned and not in state -> AssignRunnerOp", - in_process_runners=[], - state=make_state( - runner_specs_per_instance={ - INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())] - }, - model_id=MODEL_A_ID, - instance_status=InstanceStatus.Active, # Either active or inactive should yield the same. - ), - expected_op=AssignRunnerOp( - instance_id=INSTANCE_1_ID, - runner_id=RUNNER_1_ID, - shard_metadata=PipelineShardMetadata( - device_rank=0, - world_size=1, - model_meta=model_a_meta, - start_layer=0, - end_layer=1, - n_layers=1, - ), - hosts=[], - ), - ), - PlanTestCase( - description="runner assigned but no longer in state -> UnassignRunnerOp", - in_process_runners=[ - InProcessRunner( - runner_id=RUNNER_1_ID, - instance_id=INSTANCE_1_ID, - model_id=MODEL_A_ID, - status=InactiveRunnerStatus(), - downloaded=False, - ) - ], - state=State( - node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={} - ), - expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="ready runner (and state up) -> expect RunnerUpOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": InactiveRunnerStatus(), - "downloaded": True, - } - ], - instance_status=InstanceStatus.Active, - expected_op=RunnerUpOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="1 ready, 1 downloading (and state up) -> no-op", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": InactiveRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": DownloadingRunnerStatus( - download_progress=DownloadPending(node_id=NODE_A) - ), - "downloaded": False, - }, - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=None, - ), - make_test_case( - description="2 ready runners (and state up) -> expect RunnerUpOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": InactiveRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": 
RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": InactiveRunnerStatus(), - "downloaded": True, - }, - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=RunnerUpOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="loaded runner (and state down) -> expect RunnerDownOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - } - ], - instance_status=InstanceStatus.Inactive, - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="failed runner (and state down) -> expect RunnerDownOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": FailedRunnerStatus(), - "downloaded": True, - } - ], - instance_status=InstanceStatus.Inactive, - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="loaded runner, model present, task pending -> expect ExecuteTaskOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - } - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=ExecuteTaskOp( - runner_id=RUNNER_1_ID, - task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_status=TaskStatus.Pending, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ - ChatCompletionMessage(role="user", content="Hello, world!") - ], - ), - ), - ), - ), - # We should only run rank 0 once all other ranks are running. 
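-        # (A rank-0 runner with a pending task stays a no-op while any other rank is
-        # still only Loaded, and becomes an ExecuteTaskOp once every other rank reports
-        # RunningRunnerStatus - the next three cases pin down this ordering.)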
- make_test_case( - description="two loaded runners & task, i'm rank 0 -> no-op", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=None, - ), - make_test_case( - description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 1, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=ExecuteTaskOp( - runner_id=RUNNER_1_ID, - task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ - ChatCompletionMessage(role="user", content="Hello, world!") - ], - ), - task_status=TaskStatus.Pending, - ), - ), - ), - make_test_case( - description="rank 1 loaded, rank 0 ready, i'm rank 0 -> expect ExecuteTaskOp on rank 0", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": RunningRunnerStatus(), - "downloaded": True, - }, - ], - tasks=[ - { - "task_id": TASK_1_ID, - "instance_id": INSTANCE_1_ID, - "status": TaskStatus.Pending, - "messages": [{"role": "user", "content": "Hello, world!"}], - } - ], - instance_status=InstanceStatus.Active, - expected_op=ExecuteTaskOp( - runner_id=RUNNER_1_ID, - task=ChatCompletionTask( - task_id=TASK_1_ID, - command_id=COMMAND_1_ID, - instance_id=INSTANCE_1_ID, - task_params=ChatCompletionTaskParams( - model=str(MODEL_A_ID), - messages=[ - ChatCompletionMessage(role="user", content="Hello, world!") - ], - ), - task_status=TaskStatus.Pending, - ), - ), - ), - make_test_case( - description="this runner failed (1 node) -> RunnerDownOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": FailedRunnerStatus(), - "downloaded": True, - } - ], - instance_status=InstanceStatus.Active, - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="other runner failed -> RunnerDownOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": FailedRunnerStatus(), - "downloaded": True, - }, - ], - instance_status=InstanceStatus.Active, - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - make_test_case( - description="this runner failed (2 nodes) -> no-op", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": FailedRunnerStatus(), - "downloaded": 
True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": LoadedRunnerStatus(), - "downloaded": True, - }, - ], - instance_status=InstanceStatus.Active, - expected_op=None, - ), - make_test_case( - description="this node failed, other node spun down -> RunnerDownOp", - runner_specs=[ - { - "runner_id": RUNNER_1_ID, - "node_id": NODE_A, - "device_rank": 0, - "status": FailedRunnerStatus(), - "downloaded": True, - }, - { - "runner_id": RUNNER_2_ID, - "node_id": NODE_B, - "device_rank": 1, - "status": InactiveRunnerStatus(), - "downloaded": True, - }, - ], - instance_status=InstanceStatus.Active, - expected_op=RunnerDownOp(runner_id=RUNNER_1_ID), - ), - ] - - -# --------------------------------------------------------------------------- -# Parametrised test -# --------------------------------------------------------------------------- - - -# Pre-compute readable identifiers for each case to avoid lambda typing issues. -@pytest.mark.parametrize( - "case", - # We use a factory to delay test case generation until tmp_path is available. - [pytest.param(c, id=c.id()) for c in _get_test_cases()], -) -def test_worker_plan(case: PlanTestCase, worker_void_mailbox: Worker) -> None: - """Exercise Worker.plan across declarative scenarios.""" - - print(f"----- case: {case.description}") - - # Regenerate test cases with the actual tmp_path fixture - test_cases = {c.description: c for c in _get_test_cases()} - case = test_cases[case.description] - - worker = worker_void_mailbox - - runner_config: InProcessRunner - for runner_config in case.in_process_runners: - if len(case.state.instances) == 1: - instance_id = next(iter(case.state.instances)) - - shard_assignments = case.state.instances[instance_id].shard_assignments - shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id] - - # Only add this runner if it belongs to our node - runner_node = None - for node, runner in shard_assignments.node_to_runner.items(): - if runner == runner_config.runner_id: - runner_node = node - break - - if runner_node != worker.node_id: - # This runner belongs to a different node, skip it - continue - - elif len(case.state.instances) == 0: - shard_metadata = PipelineShardMetadata( - device_rank=runner_config.device_rank, - world_size=1, - model_meta=make_model_meta(runner_config.model_id), - start_layer=0, - end_layer=1, - n_layers=1, - ) - else: - raise Exception( - "test_worker_plan not currently designed to have more than 1 instance." 
- ) - - assigned_runner = AssignedRunner( - runner_id=runner_config.runner_id, - instance_id=runner_config.instance_id, - shard_metadata=shard_metadata, - hosts=[], - status=runner_config.status, - runner=None, - ) - worker.assigned_runners[runner_config.runner_id] = assigned_runner - - op = plan( - worker.assigned_runners, - NODE_A, - case.state.instances, - case.state.runners, - case.state.tasks, - ) - assert op == case.expected_op diff --git a/src/exo/worker/tests/test_plan/test_worker_plan_utils.py b/src/exo/worker/tests/test_plan/test_worker_plan_utils.py deleted file mode 100644 index 9053df1f..00000000 --- a/src/exo/worker/tests/test_plan/test_worker_plan_utils.py +++ /dev/null @@ -1,292 +0,0 @@ -from dataclasses import dataclass -from typing import NotRequired, TypedDict - -from typing_extensions import Literal - -from exo.shared.models.model_cards import MODEL_CARDS, ModelCard -from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams -from exo.shared.types.common import CommandId, NodeId -from exo.shared.types.memory import Memory -from exo.shared.types.models import ModelId, ModelMetadata -from exo.shared.types.state import State -from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus -from exo.shared.types.worker.common import InstanceId, RunnerId, WorkerStatus -from exo.shared.types.worker.downloads import DownloadOngoing, DownloadProgressData -from exo.shared.types.worker.instances import Instance, InstanceStatus -from exo.shared.types.worker.ops import RunnerOp -from exo.shared.types.worker.runners import ( - DownloadingRunnerStatus, - RunnerStatus, - RunningRunnerStatus, - ShardAssignments, -) -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.tests.constants import COMMAND_1_ID, INSTANCE_1_ID, MODEL_A_ID - - -class RunnerSpecDict(TypedDict): - """Type definition for runner specification dictionaries.""" - - runner_id: RunnerId - node_id: NodeId - device_rank: int - status: RunnerStatus - downloaded: NotRequired[bool] # defaults to True if not provided - - -class MessageDict(TypedDict): - """Type definition for message dictionaries.""" - - role: Literal["system", "user", "assistant", "developer", "tool", "function"] - content: NotRequired[str | None] - name: NotRequired[str | None] - tool_calls: NotRequired[list[dict[str, str]] | None] - tool_call_id: NotRequired[str | None] - function_call: NotRequired[dict[str, str] | None] - - -class TaskSpecDict(TypedDict): - """Type definition for task specification dictionaries.""" - - task_id: TaskId - instance_id: NotRequired[ - InstanceId - ] # defaults to function parameter if not provided - command_id: NotRequired[CommandId] # defaults to COMMAND_1_ID if not provided - status: NotRequired[TaskStatus] # defaults to TaskStatus.PENDING if not provided - model: NotRequired[str] # defaults to model_id if not provided - messages: NotRequired[ - list[MessageDict] - ] # defaults to [{'role': 'user', 'content': 'Hello, world!'}] if not provided - - -@dataclass(slots=True, frozen=True) -class InProcessRunner: - """Minimal description of a runner's in-process state.""" - - runner_id: RunnerId - instance_id: InstanceId - model_id: ModelId - status: RunnerStatus - downloaded: bool - device_rank: int = 0 - - -@dataclass(slots=True, frozen=True) -class PlanTestCase: - """Table-driven description of an entire planning scenario.""" - - description: str - state: State - in_process_runners: list[InProcessRunner] - expected_op: RunnerOp | None - - def id(self) -> str: # noqa: 
D401
-        return self.description.replace(" ", "_")
-
-
-def make_shard_metadata(
-    device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID
-) -> PipelineShardMetadata:
-    """Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
-    total_layers = world_size  # For simplicity in tests, total_layers = world_size
-
-    if world_size == 1:
-        start_layer = 0
-        end_layer = 1
-        n_layers = 1
-    else:
-        # For multi-device setup, each device gets one layer
-        start_layer = device_rank
-        end_layer = device_rank + 1
-        n_layers = total_layers
-
-    return PipelineShardMetadata(
-        device_rank=device_rank,
-        world_size=world_size,
-        model_meta=make_model_meta(model_id),
-        start_layer=start_layer,
-        end_layer=end_layer,
-        n_layers=n_layers,
-    )
-
-
-def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
-    """Factory for a *Downloading* status with placeholder progress."""
-    return DownloadingRunnerStatus(
-        download_progress=DownloadOngoing(
-            node_id=node_id,
-            download_progress=DownloadProgressData(
-                total_bytes=Memory.from_bytes(1),
-                downloaded_bytes=Memory.from_bytes(0),
-                downloaded_bytes_this_session=Memory.from_bytes(0),
-                completed_files=0,
-                total_files=0,
-                speed=0,
-                eta_ms=0,
-                files={},
-            ),
-        )
-    )
-
-
-def make_model_meta(model_id: str) -> ModelMetadata:
-    model_card: ModelCard | None = None
-    for card in MODEL_CARDS.values():
-        if card.model_id == model_id:
-            model_card = card
-            break
-
-    if model_card is None:
-        raise Exception(f"Unknown model_id passed: {model_id}")
-
-    return ModelMetadata(
-        model_id=ModelId(model_id),
-        pretty_name=model_card.model_id,
-        storage_size=Memory.from_kb(10**6),
-        n_layers=16,
-    )
-
-    ## Alternatively, if we are ok for this method to be async:
-    # await _get_model_meta(model_id)
-
-
-def make_instance(
-    instance_id: InstanceId,
-    runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
-    model_id: ModelId = MODEL_A_ID,
-    instance_status: InstanceStatus = InstanceStatus.Active,
-) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, WorkerStatus]]:
-    """Creates an instance with one or more runners."""
-    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
-    node_to_runner: dict[NodeId, RunnerId] = {}
-    world_size = len(runner_specs)
-
-    for runner_id, node_id, device_rank, _ in runner_specs:
-        shard_metadata = make_shard_metadata(device_rank, world_size, model_id)
-        runner_to_shard[runner_id] = shard_metadata
-        node_to_runner[node_id] = runner_id
-
-    shard_assignments = ShardAssignments(
-        model_id=model_id,
-        runner_to_shard=runner_to_shard,
-        node_to_runner=node_to_runner,
-    )
-    instance = Instance(
-        instance_id=instance_id,
-        instance_type=instance_status,
-        shard_assignments=shard_assignments,
-        hosts=[],
-    )
-
-    # Currently nodes are only ever idle - as if they were running we would be blocking - so we wouldn't be running plan()
-    # node_statuses = {node_id: WorkerStatus.Idle for _, node_id, _, _ in runner_specs}
-    node_statuses: dict[NodeId, WorkerStatus] = {}
-    for _runner_id, node_id, _, status in runner_specs:
-        if isinstance(status, RunningRunnerStatus):
-            node_statuses[node_id] = WorkerStatus.Running
-        else:
-            node_statuses[node_id] = WorkerStatus.Idle
-    runner_statuses = {runner_id: status for runner_id, _, _, status in runner_specs}
-
-    return instance, runner_statuses, node_statuses
-
-
-def make_state(
-    runner_specs_per_instance: dict[
-        InstanceId, list[tuple[RunnerId, NodeId, int, RunnerStatus]]
-    ],
-    tasks: dict[TaskId, ChatCompletionTask] | None = None,
-    model_id: ModelId = MODEL_A_ID,
-    
instance_status: InstanceStatus = InstanceStatus.Active, -) -> State: - """Builds a full State from runner specs per instance, tasks, and defaults.""" - if tasks is None: - tasks = {} - instances: dict[InstanceId, Instance] = {} - all_runner_statuses: dict[RunnerId, RunnerStatus] = {} - all_node_statuses: dict[NodeId, WorkerStatus] = {} - - for inst_id, specs in runner_specs_per_instance.items(): - # Build per-instance data using make_instance - instance, runner_statuses, node_statuses = make_instance( - instance_id=inst_id, - runner_specs=specs, - model_id=model_id, - instance_status=instance_status, - ) - instances[inst_id] = instance - all_runner_statuses.update(runner_statuses) - all_node_statuses.update(node_statuses) - - return State( - node_status=all_node_statuses, - instances=instances, - runners=all_runner_statuses, - tasks=tasks, - ) - - -def make_test_case( - description: str, - runner_specs: list[RunnerSpecDict], - tasks: list[TaskSpecDict] | None = None, - expected_op: RunnerOp | None = None, - instance_id: InstanceId = INSTANCE_1_ID, - instance_status: InstanceStatus = InstanceStatus.Active, - model_id: ModelId = MODEL_A_ID, - command_id: CommandId = COMMAND_1_ID, # Default for tasks -) -> PlanTestCase: - """Builds a PlanTestCase from high-level specs.""" - if tasks is None: - tasks = [] - # Convert runner_specs to tuple format for make_instance - specs_tuple = [ - (r["runner_id"], r["node_id"], r["device_rank"], r["status"]) - for r in runner_specs - ] - - # Build state using make_state (wrap single instance) - state_tasks: dict[TaskId, ChatCompletionTask] = {} - for t in tasks: - task = ChatCompletionTask( - instance_id=instance_id, - task_id=t["task_id"], - command_id=t.get("command_id", command_id), - task_status=t.get("status", TaskStatus.Pending), - task_params=ChatCompletionTaskParams( - model=t.get("model", str(model_id)), - messages=[ - ChatCompletionMessage(**m) - for m in t.get( - "messages", [{"role": "user", "content": "Hello, world!"}] - ) - ], - ), - ) - state_tasks[t["task_id"]] = task - - state = make_state( - runner_specs_per_instance={instance_id: specs_tuple}, - tasks=state_tasks, - model_id=model_id, - instance_status=instance_status, - ) - - # Build in_process_runners with downloaded (default True if missing) - in_process_runners = [ - InProcessRunner( - runner_id=r["runner_id"], - instance_id=instance_id, - model_id=model_id, - status=r["status"], - downloaded=r.get("downloaded", True), - device_rank=r["device_rank"], - ) - for r in runner_specs - ] - - return PlanTestCase( - description=description, - state=state, - in_process_runners=in_process_runners, - expected_op=expected_op, - ) diff --git a/src/exo/worker/tests/test_runner_connection.py b/src/exo/worker/tests/test_runner_connection.py deleted file mode 100644 index a887b866..00000000 --- a/src/exo/worker/tests/test_runner_connection.py +++ /dev/null @@ -1,181 +0,0 @@ -import asyncio -import os -from typing import Callable - -import pytest -from anyio import create_task_group, move_on_after - -from exo.shared.types.common import Host -from exo.shared.types.events import InstanceCreated, InstanceDeleted -from exo.shared.types.models import ModelId -from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments -from exo.shared.types.worker.runners import FailedRunnerStatus -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.main import Worker -from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.tests.constants 
import (
-    INSTANCE_1_ID,
-    MASTER_NODE_ID,
-    NODE_A,
-    NODE_B,
-    RUNNER_1_ID,
-    RUNNER_2_ID,
-)
-from exo.worker.tests.worker_management import WorkerMailbox
-
-
-@pytest.fixture
-def user_message() -> str:
-    return "What is the capital of Japan?"
-
-
-@pytest.mark.skipif(
-    os.environ.get("DETAILED", "").lower() != "true",
-    reason="This test only runs when the DETAILED=true environment variable is set",
-)
-async def check_runner_connection(
-    pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
-    hosts: Callable[[int], list[Host]],
-    two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
-) -> bool:
-    async def wait_for_runner_supervisor(
-        worker: Worker, timeout: float = 5.0
-    ) -> RunnerSupervisor | None:
-        with move_on_after(timeout):
-            while True:
-                assigned_runners = list(worker.assigned_runners.values())
-                if assigned_runners:
-                    runner = assigned_runners[0].runner
-                    if isinstance(runner, RunnerSupervisor):
-                        print("breaking because success")
-                        return runner
-                    if isinstance(assigned_runners[0].status, FailedRunnerStatus):
-                        print("breaking because failed")
-                        return runner
-                await asyncio.sleep(0.001)
-
-    worker1, worker2, global_events = two_workers_with_shared_mailbox
-    # Track all tasks and workers for cleanup
-    async with create_task_group() as tg:
-        tg.start_soon(worker1.run)
-        tg.start_soon(worker2.run)
-        model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit")
-
-        shard_assignments = ShardAssignments(
-            model_id=model_id,
-            runner_to_shard={
-                RUNNER_1_ID: pipeline_shard_meta(2, 0),
-                RUNNER_2_ID: pipeline_shard_meta(2, 1),
-            },
-            node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
-        )
-
-        instance = Instance(
-            instance_id=INSTANCE_1_ID,
-            instance_type=InstanceStatus.Active,
-            shard_assignments=shard_assignments,
-            hosts=hosts(2),
-        )
-
-        await global_events.append_events(
-            [
-                InstanceCreated(instance=instance),
-            ],
-            origin=MASTER_NODE_ID,
-        )
-
-        runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0)
-        ret = (
-            runner_supervisor is not None
-            and runner_supervisor.runner_process.is_alive()
-        )
-
-        await global_events.append_events(
-            [
-                InstanceDeleted(
-                    instance_id=instance.instance_id,
-                ),
-            ],
-            origin=MASTER_NODE_ID,
-        )
-
-        await asyncio.sleep(0.5)
-
-        worker1.shutdown()
-        worker2.shutdown()
-        tg.cancel_scope.cancel()
-
-        return ret
-    # should be unreachable
-    raise
-
-
-# Check Running status
-
-# # not now.
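-# A minimal sketch of the loop the commented-out test below is driving at (assuming
-# check_runner_connection keeps the signature above): asyncio.run gives every
-# iteration a fresh event loop and cancels any leftover tasks on teardown, which is
-# what the manual new_event_loop/cancel/close bookkeeping below does by hand.
-#
-#     def run_connection_stress(total_runs: int = 100) -> None:
-#         successes = 0
-#         for _ in range(total_runs):
-#             if asyncio.run(check_runner_connection(...)):
-#                 successes += 1
-#         print(f"Runner connection successes: {successes} / {total_runs}")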
-
-# def test_runner_connection_stress(
-#     pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
-#     hosts: Callable[[int], list[Host]],
-#     chat_completion_task: Callable[[InstanceId, str], Task],
-# ) -> None:
-#     total_runs = 100
-#     successes = 0
-
-#     for _ in range(total_runs):
-#         # Create a fresh event loop for each iteration
-#         loop = asyncio.new_event_loop()
-#         asyncio.set_event_loop(loop)
-
-#         try:
-#             result = loop.run_until_complete(check_runner_connection(
-#                 pipeline_shard_meta=pipeline_shard_meta,
-#                 hosts=hosts,
-#                 chat_completion_task=chat_completion_task,
-#             ))
-#             if result:
-#                 successes += 1
-#         finally:
-#             # Cancel all running tasks
-#             pending = asyncio.all_tasks(loop)
-#             for task in pending:
-#                 task.cancel()
-
-#             # Run the event loop briefly to allow cancellation to complete
-#             loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-
-#             # Close the event loop
-#             loop.close()
-
-#     print(f"Runner connection successes: {successes} / {total_runs}")
diff --git a/src/exo/worker/tests/test_serdes.py b/src/exo/worker/tests/test_serdes.py
deleted file mode 100644
index 58c9c307..00000000
--- a/src/exo/worker/tests/test_serdes.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Callable
-
-from pydantic import BaseModel, TypeAdapter
-
-from exo.shared.types.common import Host
-from exo.shared.types.tasks import Task, TaskId
-from exo.shared.types.worker.commands_runner import (
-    ChatTaskMessage,
-    RunnerMessage,
-    SetupMessage,
-)
-from exo.shared.types.worker.common import InstanceId
-from exo.shared.types.worker.shards import PipelineShardMetadata
-
-
-def assert_equal_serdes[T: BaseModel](obj: T, typeadapter: TypeAdapter[T]):
-    encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
-    decoded: T = typeadapter.validate_json(encoded)
-
-    assert decoded == obj, (
-        f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}"
-    )
-
-
-def test_supervisor_setup_message_serdes(
-    pipeline_shard_meta: Callable[..., PipelineShardMetadata],
-    hosts: Callable[..., list[Host]],
-):
-    setup_message = SetupMessage(
-        model_shard_meta=pipeline_shard_meta(1, 0),
-        hosts=hosts(1),
-    )
-    assert_equal_serdes(setup_message, TypeAdapter(RunnerMessage))
-
-
-def test_supervisor_task_message_serdes(
-    chat_completion_task: Callable[[InstanceId, TaskId], Task],
-):
-    task = chat_completion_task(InstanceId(), TaskId())
-    task_message = ChatTaskMessage(
-        task_data=task.task_params,
-    )
-    assert_equal_serdes(task_message, TypeAdapter(RunnerMessage))
diff --git a/src/exo/worker/tests/test_spinup_timeout.py b/src/exo/worker/tests/test_spinup_timeout.py
deleted file mode 100644
index 3780023a..00000000
--- a/src/exo/worker/tests/test_spinup_timeout.py
+++ /dev/null
@@ -1,50 +0,0 @@
-## Tests for worker state handlers
-
-import os
-from typing import Callable
-
-import pytest
-
-from exo.shared.types.events import (
-    Event,
-    RunnerStatusUpdated,
-)
-from exo.shared.types.tasks import Task, TaskId
-from exo.shared.types.worker.instances import Instance, InstanceId
-from exo.shared.types.worker.ops import (
-    RunnerUpOp,
-)
-from exo.shared.types.worker.runners import FailedRunnerStatus
-from exo.worker.main import Worker
-from exo.worker.tests.constants import RUNNER_1_ID
-
-# To enable this test, run pytest with: DETAILED=true pytest
-
-
-@pytest.mark.skipif(
-    os.environ.get("DETAILED", "").lower() != "true",
-    reason="This test only runs when the DETAILED=true environment variable is set",
-)
-@pytest.mark.asyncio
-async def test_runner_up_op_timeout(
-    worker_with_assigned_runner: tuple[Worker, Instance],
-    chat_completion_task: Callable[[InstanceId, TaskId], Task],
-    monkeypatch: pytest.MonkeyPatch,
-):
-    worker, _ = worker_with_assigned_runner
-
-    runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
-
-    # _execute_runner_up_op should throw a TimeoutError with a short timeout
-    events: list[Event] = []
-    async for event in worker._execute_runner_up_op(  # type: ignore[misc]
-        runner_up_op, initialize_timeout=0.2
-    ):
-        events.append(event)
-
-    assert isinstance(events[-1], RunnerStatusUpdated)
-    assert isinstance(events[-1].runner_status, FailedRunnerStatus)
-    assert events[-1].runner_status.error_message is not None
-    assert "timeout" in events[-1].runner_status.error_message.lower()
-
-    del worker.assigned_runners[list(worker.assigned_runners.keys())[0]]
diff --git a/src/exo/worker/tests/test_supervisor/test_long.py b/src/exo/worker/tests/test_supervisor/test_long.py
deleted file mode 100644
index 89f81969..00000000
--- a/src/exo/worker/tests/test_supervisor/test_long.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import asyncio
-from typing import Callable
-
-import pytest
-
-from exo.shared.models.model_cards import MODEL_CARDS
-from exo.shared.openai_compat import FinishReason
-from exo.shared.types.chunks import TokenChunk
-from exo.shared.types.common import Host
-from exo.shared.types.tasks import (
-    Task,
-    TaskId,
-)
-from exo.shared.types.worker.common import InstanceId
-from exo.shared.types.worker.shards import PipelineShardMetadata
-from exo.worker.runner.runner_supervisor import RunnerSupervisor
-
-
-@pytest.fixture
-def user_message():
-    """Override the default message to ask about France's capital"""
-    return "What is the capital of France?"
-
-
-@pytest.fixture
-def lorem_ipsum() -> str:
-    return """
-Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Phasellus rhoncus felis in velit tempus tristique. Nullam ipsum lectus, tristique a eros quis, ullamcorper accumsan lorem. Aliquam ut auctor elit, finibus porttitor neque. In cursus augue facilisis ante ullamcorper, at sollicitudin quam aliquam. Etiam ac lacinia lacus, et aliquet nunc. Phasellus nisi ex, feugiat quis dolor non, mollis consequat nulla. Suspendisse gravida, sem non lobortis viverra, turpis lacus elementum orci, in tristique augue tortor nec mauris. Curabitur aliquet lorem in rhoncus mollis. Aliquam pulvinar elit odio, ac feugiat magna luctus nec. Pellentesque non risus egestas, pellentesque arcu tincidunt, gravida risus. Etiam ut lorem ac lorem pharetra efficitur. Donec augue arcu, varius nec lorem vitae, suscipit semper tellus. Aliquam dignissim quis augue id fermentum. Proin aliquet pellentesque est, eget tincidunt odio ullamcorper vel. Suspendisse potenti. -Aenean imperdiet justo sit amet erat aliquet tristique. Sed tempus, turpis a cursus lobortis, ante sem imperdiet est, eu dapibus sapien velit eget elit. Donec feugiat sed risus sed scelerisque. Donec posuere tempor orci, sit amet pellentesque est efficitur non. Vivamus sodales pretium purus, sed rutrum enim auctor ut. Cras pharetra vitae libero et hendrerit. Sed nec tempus odio. Proin blandit facilisis scelerisque. Nulla in mattis mi. Etiam bibendum efficitur aliquam. Proin ut risus aliquet, rhoncus lectus non, rhoncus arcu. Nam nibh felis, ultrices a elit sed, ultricies sollicitudin tellus. Interdum et malesuada fames ac ante ipsum primis in faucibus. Maecenas faucibus magna ut purus imperdiet faucibus. Nam fermentum nulla fermentum magna aliquam, vel lacinia neque euismod. Donec tincidunt sed neque non facilisis. -Proin id lorem cursus, vehicula ante non, lacinia metus. Nam egestas dui a iaculis convallis. Ut suscipit justo est, nec pharetra ante accumsan ac. Pellentesque nec nisi ipsum. Duis non arcu neque. Curabitur non luctus purus. Phasellus pulvinar commodo lacus sit amet auctor. Ut ut mattis metus, eu auctor arcu. Etiam a suscipit est. Morbi orci mauris, suscipit tempus fermentum vel, luctus faucibus lectus. Aliquam a euismod arcu. Suspendisse porttitor eget libero vitae laoreet. -Fusce congue lorem mi, a mollis felis efficitur quis. Quisque lobortis scelerisque arcu, a varius sapien. Nulla eget orci non urna imperdiet tincidunt. Nunc mi massa, consectetur id lorem consectetur, molestie dignissim sem. Suspendisse et augue magna. Mauris id tempus velit, cursus suscipit tortor. Duis non mi non nisi fringilla maximus in et erat. -Proin consequat sapien eget tellus aliquam ultrices. Nunc hendrerit semper massa, pulvinar sodales ipsum condimentum eu. Proin vel ligula venenatis, lobortis lectus eu, vehicula justo. Mauris eu arcu at orci vehicula feugiat non eu metus. Duis ut vestibulum quam. Maecenas dolor elit, egestas ut purus sit amet, convallis lobortis massa. Ut volutpat augue ac ante consectetur dignissim. Maecenas vitae felis elementum, semper augue eu, auctor dolor. Ut pulvinar convallis tortor non volutpat. Curabitur vulputate sem sodales sapien pretium ultrices. Sed luctus libero vitae urna eleifend tincidunt. Proin pulvinar imperdiet cursus. Suspendisse ullamcorper laoreet leo dapibus tincidunt. Pellentesque molestie elementum felis. -Integer vitae congue nulla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vestibulum elit velit, malesuada quis ipsum et, imperdiet varius velit. Nam tristique viverra maximus. Curabitur eget semper lectus. 
Sed vitae lorem sit amet mi lacinia posuere ac a risus. Pellentesque et magna nisl. In hac habitasse platea dictumst. Aenean suscipit, nibh vitae sollicitudin commodo, risus mi commodo neque, nec venenatis velit augue sed massa. Nam tempus, arcu id eleifend auctor, est dui viverra odio, vel convallis arcu dolor id quam. Ut malesuada ligula vel interdum eleifend. In posuere ultrices tincidunt. Sed non enim sit amet lectus sagittis mattis eu at sapien. Pellentesque eu urna mollis, vehicula dolor eget, lobortis nisl. Suspendisse ex nisi, iaculis non sapien ac, fringilla rutrum dolor. Quisque pretium mauris nec ante gravida, sed laoreet neque viverra. -Donec mattis orci sit amet tincidunt maximus. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Curabitur tristique venenatis lectus, vel pulvinar sem. Sed vel dolor lacinia, aliquet nisi ac, bibendum libero. Nullam vulputate euismod augue ac imperdiet. Proin at fermentum sapien. Nam et fringilla lorem. Aenean sed lacus sed tellus sodales mattis ut rutrum ex. Nulla ligula diam, interdum quis faucibus sit amet, laoreet vel massa. Fusce mauris massa, tempor quis tempus nec, dictum a ligula. Ut at dapibus sapien. Nullam sem lorem, sollicitudin non dui a, consequat molestie mauris. Quisque sem nulla, vehicula nec vulputate ac, viverra in massa. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur pretium venenatis nisi non bibendum. Nam vitae ligula auctor, rutrum lectus eget, feugiat augue. -Ut nunc risus, vehicula at metus non, consequat suscipit risus. Mauris eget sem in neque tincidunt iaculis. Pellentesque lacus leo, molestie ut pharetra sit amet, porta nec neque. Aliquam eu bibendum odio. Proin tempus bibendum ornare. Morbi non risus vitae ante tempor porta quis sed augue. Nullam hendrerit nulla in eleifend tincidunt. Integer suscipit ligula at nunc blandit vehicula. Nam porttitor leo in turpis suscipit malesuada. Etiam sodales nunc nisi, pharetra malesuada nibh varius in. Cras quis pellentesque augue, vitae convallis velit. In et dui lorem. Integer semper eros eget augue posuere, ac elementum tellus convallis. Praesent blandit tempus ultrices. Suspendisse nec dui vitae neque varius eleifend. Sed pretium metus leo, id viverra tellus scelerisque in. -Aenean sodales urna vitae lobortis cursus. Sed vitae pellentesque erat, fermentum pellentesque urna. Suspendisse potenti. Sed porttitor placerat turpis non vestibulum. Duis in nisi non purus venenatis tempus non eu nisi. Sed bibendum sapien vitae ultricies condimentum. Integer vel mattis lectus, consequat congue ex. Cras convallis odio volutpat nulla vehicula efficitur. Pellentesque eget justo neque. Morbi mattis vitae magna et suscipit. Etiam orci sapien, tincidunt non tellus eget, laoreet vestibulum massa. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris nec nisi enim. Donec risus odio, lobortis in odio malesuada, laoreet rutrum urna. Nunc sit amet euismod quam. -Fusce rhoncus ullamcorper nunc, ut pellentesque nisi dictum sed. Fusce sem mi, bibendum ut dictum at, porta in libero. Pellentesque placerat mollis sapien, sed eleifend lorem consequat in. Phasellus vel tempor ligula. Pellentesque tincidunt suscipit tortor vel blandit. Maecenas purus mi, mattis ac aliquam vel, rutrum eu nulla. Proin rhoncus nec sem a congue. Pellentesque sit amet sapien quam. Sed hendrerit neque id venenatis dignissim. -Vestibulum laoreet eu felis nec aliquam. Praesent gravida ornare odio nec porttitor. 
Donec ut tellus eros. Proin fringilla urna augue, vitae ornare leo varius non. Curabitur consectetur, purus in iaculis finibus, lectus lacus porttitor dolor, nec eleifend tellus massa eget tellus. Mauris sit amet convallis risus, a fermentum lorem. Suspendisse potenti. Curabitur vulputate finibus maximus. Interdum et malesuada fames ac ante ipsum primis in faucibus. In vel erat pellentesque, rhoncus magna vel, scelerisque mauris. -Nulla facilisi. Morbi mattis felis nec accumsan varius. Vestibulum in sodales arcu. Vivamus egestas, ante nec dapibus vestibulum, tellus ipsum rhoncus mi, at fermentum sapien justo nec turpis. Quisque rhoncus, urna sit amet imperdiet cursus, tortor lacus ultricies sapien, eu bibendum ligula enim id mi. Sed sem leo, pharetra in pulvinar sed, faucibus sed dui. Morbi tempus erat nec neque placerat tincidunt. -Quisque ut lorem sodales magna faucibus mattis. Aenean dui neque, gravida ut fringilla non, fermentum sit amet dolor. Mauris a sapien lacinia, elementum dolor in, sagittis metus. Donec viverra magna non lorem rutrum, at eleifend lacus volutpat. Nunc sit amet dolor tempor, blandit sapien a, consectetur magna. Suspendisse maximus nunc nec imperdiet aliquet. Nunc aliquam interdum purus quis pretium. Mauris molestie feugiat pellentesque. Nunc maximus, est sed consequat malesuada, risus turpis consequat velit, ac feugiat nunc magna vitae ligula. Vestibulum tincidunt massa ante, vitae pellentesque tortor rutrum sed. Aliquam vel est libero. Suspendisse et convallis orci. Cras sed lorem consectetur, blandit massa sit amet, semper neque. Vestibulum et mi euismod, imperdiet justo non, facilisis libero. -Sed at lacus ac tortor dictum tempus. Integer commodo purus lacus, ut pretium est tempor ac. Ut vulputate nulla magna, ac facilisis velit commodo in. Interdum et malesuada fames ac ante ipsum primis in faucibus. Donec pellentesque congue nibh nec eleifend. Ut ante turpis, sodales sed aliquet quis, tempus eu dui. Proin et eros non risus porttitor pharetra. -Mauris a urna id justo gravida ultrices. Mauris commodo sed ipsum a dictum. In posuere luctus scelerisque. Morbi sit amet gravida ipsum. Quisque vel dui sit amet ex lobortis eleifend non vel neque. Fusce sit amet imperdiet felis, eu tempor diam. Pellentesque sit amet turpis in libero tristique posuere. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris quis est suscipit, tristique odio elementum, molestie nibh. Maecenas ex dui, pulvinar quis pellentesque sed, imperdiet nec mauris. Pellentesque ultrices at mauris eget fringilla. Donec bibendum rhoncus felis, ut pretium nulla eleifend commodo. -Ut euismod erat accumsan tincidunt sagittis. Proin eget massa ex. Suspendisse at faucibus enim, vitae posuere mi. Cras nec ex finibus, porttitor purus quis, efficitur libero. Nulla sagittis ornare iaculis. Donec venenatis dui ut libero aliquam lobortis. Vestibulum imperdiet lorem urna, eget gravida orci sollicitudin ut. Quisque ultrices tortor at quam laoreet aliquet. Pellentesque tincidunt consequat pharetra. Cras a lacinia erat. Mauris sed neque lobortis ipsum facilisis hendrerit. -Cras at orci odio. Curabitur eros metus, consequat non placerat et, tincidunt at turpis. Morbi quis viverra metus. Vestibulum molestie, ex at suscipit finibus, ex magna pellentesque nisi, eu ullamcorper nisl sapien eu quam. Phasellus volutpat lacinia enim, nec fermentum augue tincidunt ut. Duis rutrum purus eu nulla elementum, a faucibus odio fringilla. 
Sed cursus risus neque, dictum luctus tortor tempus eu. -Mauris non arcu eu nunc faucibus tincidunt id quis dolor. Quisque ac fringilla libero. Sed non ligula ut nunc auctor consequat vitae eget metus. Ut suscipit leo quam, vitae ultrices urna feugiat eu. Vestibulum volutpat nisl quis nunc pretium, vel viverra orci fringilla. Proin erat nibh, laoreet nec nisi sit amet, volutpat efficitur nunc. Cras id tortor quis lectus imperdiet rutrum non id purus. Proin efficitur ligula non dapibus consectetur. Nam quis quam eget dui facilisis scelerisque. Praesent non bibendum risus. Etiam imperdiet nisi id consectetur porta. In pretium nulla ut leo ultricies rhoncus. -Curabitur non vehicula purus. Cras et justo risus. Duis et rutrum urna. Aliquam condimentum purus nec ante dignissim rhoncus. Vestibulum commodo pharetra eros, ac euismod orci rutrum vel. Integer sed cursus erat, euismod accumsan libero. Nullam ut odio sit amet nibh tempor congue. Pellentesque porttitor aliquam ipsum, sit amet facilisis quam fringilla ac. Aliquam scelerisque tempor nisl in tempor. Sed vestibulum, tellus sit amet mattis pellentesque, eros diam convallis felis, id pellentesque massa leo quis dolor. Integer dignissim orci lorem, vel porttitor felis blandit et. Nam ultrices enim sed elementum accumsan. Fusce rutrum, quam et feugiat maximus, lorem leo porttitor ex, a eleifend risus odio consectetur lacus. In hac habitasse platea dictumst. Aenean pharetra erat tellus, at tempus urna iaculis ut. Ut ac mi eu lorem volutpat egestas. -Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Praesent porttitor tempor ligula. Quisque mollis arcu in metus ornare pellentesque. Aenean ultrices mollis quam quis sodales. Maecenas a cursus elit, id gravida tortor. Donec vel purus magna. Aliquam elementum est sed convallis fermentum. Nam nec eros arcu. Pellentesque sed eros a lacus sagittis maximus. Integer et tellus id libero dapibus convallis. Maecenas viverra, purus facilisis porttitor tincidunt, tellus lacus elementum dui, sed porttitor sem justo a lorem. Curabitur ipsum odio, efficitur quis efficitur at, tempus aliquet nisi. Aliquam ultrices tortor in arcu vulputate, vel iaculis lorem facilisis. Cras eleifend laoreet feugiat. Integer placerat blandit sem, mattis elementum purus pellentesque quis. Etiam vel arcu ut mi commodo placerat non id tortor. 
-""" - - -@pytest.mark.asyncio -async def test_supervisor_long_prompt_response( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - lorem_ipsum: str, -): - """Test that asking for the capital of France returns 'Paris' in the response""" - - model_meta = MODEL_CARDS["llama-3.2-1b"].metadata - model_shard_meta = PipelineShardMetadata( - model_meta=model_meta, - device_rank=0, - world_size=1, - n_layers=model_meta.n_layers, - start_layer=0, - end_layer=model_meta.n_layers, - ) - instance_id = InstanceId() - - print(f"{model_shard_meta=}") - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - try: - full_response = "" - - task = chat_completion_task(instance_id, TaskId()) - task.task_params.messages[0].content = lorem_ipsum * 3 - - async for chunk in supervisor.stream_response(task=task): - if isinstance(chunk, TokenChunk): - full_response += chunk.text - - assert len(full_response) > 100 - - finally: - await supervisor.astop() - - -@pytest.mark.asyncio -async def test_supervisor_two_node_long_prompt_response( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], - lorem_ipsum: str, -): - """Test two-node long prompt inference""" - instance_id = InstanceId() - - async def create_supervisor(shard_idx: int) -> RunnerSupervisor: - model_meta = MODEL_CARDS["llama-3.2-1b"].metadata - model_shard_meta = PipelineShardMetadata( - model_meta=model_meta, - device_rank=shard_idx, - world_size=2, - n_layers=model_meta.n_layers, - start_layer=0 if shard_idx == 0 else model_meta.n_layers // 2, - end_layer=model_meta.n_layers // 2 - if shard_idx == 0 - else model_meta.n_layers, - ) - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(2, offset=15), - ) - return supervisor - - create_supervisor_0 = asyncio.create_task(create_supervisor(0)) - create_supervisor_1 = asyncio.create_task(create_supervisor(1)) - supervisor_0, supervisor_1 = await asyncio.gather( - create_supervisor_0, create_supervisor_1 - ) - - await asyncio.sleep(0.1) - - try: - full_response_0 = "" - full_response_1 = "" - stop_reason_0: FinishReason | None = None - stop_reason_1: FinishReason | None = None - - task = chat_completion_task(instance_id, TaskId()) - task.task_params.messages[0].content = lorem_ipsum * 3 - - async def collect_response_0(): - nonlocal full_response_0, stop_reason_0 - async for chunk in supervisor_0.stream_response(task=task): - if isinstance(chunk, TokenChunk): - full_response_0 += chunk.text - if chunk.finish_reason: - stop_reason_0 = chunk.finish_reason - - async def collect_response_1(): - nonlocal full_response_1, stop_reason_1 - async for chunk in supervisor_1.stream_response(task=task): - if isinstance(chunk, TokenChunk): - full_response_1 += chunk.text - if chunk.finish_reason: - stop_reason_1 = chunk.finish_reason - - # Run both stream responses simultaneously - _ = await asyncio.gather(collect_response_0(), collect_response_1()) - - assert len(full_response_0) > 100 - assert len(full_response_1) > 100 - - finally: - await supervisor_0.astop() - await supervisor_1.astop() diff --git a/src/exo/worker/tests/test_supervisor/test_memory.py b/src/exo/worker/tests/test_supervisor/test_memory.py deleted file mode 100644 index 140923a2..00000000 --- 
a/src/exo/worker/tests/test_supervisor/test_memory.py +++ /dev/null @@ -1,58 +0,0 @@ -from multiprocessing import Process -from typing import Callable - -import psutil -import pytest - -from exo.shared.models.model_meta import get_model_meta -from exo.shared.types.common import Host -from exo.shared.types.models import ModelMetadata -from exo.shared.types.tasks import Task, TaskId -from exo.shared.types.worker.common import InstanceId, RunnerError -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID - - -def get_memory_mb(process: Process) -> float: - """ - Returns the resident set size (RSS) memory usage in MiB for the given process. - """ - ps = psutil.Process(process.pid) - rss_bytes: int = ps.memory_info().rss # type: ignore[attr-defined] - return rss_bytes / (1024 * 1024) - - -@pytest.fixture -async def model_meta() -> ModelMetadata: - return await get_model_meta("mlx-community/Llama-3.3-70B-Instruct-4bit") - - -@pytest.mark.asyncio -async def test_supervisor_inference_exception( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - process: Process = supervisor.runner_process - memory = get_memory_mb(process) - assert memory > 30 * 100 - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST FAIL" - with pytest.raises(RunnerError): - async for _ in supervisor.stream_response(task): - pass - - await supervisor.astop() - - available_memory_bytes: int = psutil.virtual_memory().available - print(available_memory_bytes // (2**30)) - assert available_memory_bytes > 30 * 2**30 diff --git a/src/exo/worker/tests/test_supervisor/test_oom.py b/src/exo/worker/tests/test_supervisor/test_oom.py deleted file mode 100644 index 8ea4c2b8..00000000 --- a/src/exo/worker/tests/test_supervisor/test_oom.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Callable - -import pytest - -from exo.shared.types.common import Host -from exo.shared.types.tasks import ( - Task, - TaskId, -) -from exo.shared.types.worker.common import InstanceId, RunnerError -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID - - -@pytest.fixture -def user_message(): - """Override the default message to ask about France's capital""" - return "What is the capital of France?" - - -@pytest.mark.asyncio -@pytest.mark.skip( - reason="Must run `sudo sysctl -w iogpu.wired_limit_mb=` and `sudo sysctl -w iogpu.wired_lwm_mb=` before running this test." 
-) -async def test_supervisor_catches_oom( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST OOM" - with pytest.raises(RunnerError) as exc_info: - async for _ in supervisor.stream_response(task): - pass - - error = exc_info.value - assert "memory" in error.error_message.lower() - - await supervisor.astop() diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor.py b/src/exo/worker/tests/test_supervisor/test_supervisor.py deleted file mode 100644 index 9a03862c..00000000 --- a/src/exo/worker/tests/test_supervisor/test_supervisor.py +++ /dev/null @@ -1,224 +0,0 @@ -import asyncio -from typing import Callable - -import pytest - -from exo.shared.openai_compat import FinishReason -from exo.shared.types.chunks import TokenChunk -from exo.shared.types.common import Host -from exo.shared.types.tasks import ( - ChatCompletionTask, - ChatCompletionTaskParams, - Task, - TaskId, -) -from exo.shared.types.worker.common import InstanceId -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.runner.runner_supervisor import RunnerSupervisor - - -@pytest.fixture -def user_message(): - """Override the default message to ask about France's capital""" - return "What is the capital of France?" - - -@pytest.mark.asyncio -async def test_supervisor_single_node_response( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - instance_id = InstanceId() - - print(f"{model_shard_meta=}") - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - try: - full_response = "" - stop_reason: FinishReason | None = None - - async for chunk in supervisor.stream_response( - task=chat_completion_task(instance_id, TaskId()) - ): - if isinstance(chunk, TokenChunk): - full_response += chunk.text - if chunk.finish_reason: - stop_reason = chunk.finish_reason - - # Case-insensitive check for Paris in the response - assert "paris" in full_response.lower(), ( - f"Expected 'Paris' in response, but got: {full_response}" - ) - assert stop_reason == "stop" - - finally: - await supervisor.astop() - - -@pytest.mark.asyncio -async def test_supervisor_two_node_response( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - instance_id = InstanceId() - - async def create_supervisor(shard_idx: int) -> RunnerSupervisor: - supervisor = await RunnerSupervisor.create( - model_shard_meta=pipeline_shard_meta(2, shard_idx), - hosts=hosts(2, offset=15), - ) - return supervisor - - create_supervisor_0 = asyncio.create_task(create_supervisor(0)) - create_supervisor_1 = asyncio.create_task(create_supervisor(1)) - supervisor_0, supervisor_1 = await asyncio.gather( 
- create_supervisor_0, create_supervisor_1 - ) - - await asyncio.sleep(0.1) - - try: - full_response_0 = "" - full_response_1 = "" - - async def collect_response_0(): - nonlocal full_response_0 - async for chunk in supervisor_0.stream_response( - task=chat_completion_task(instance_id, TaskId()) - ): - if isinstance(chunk, TokenChunk): - full_response_0 += chunk.text - - async def collect_response_1(): - nonlocal full_response_1 - async for chunk in supervisor_1.stream_response( - task=chat_completion_task(instance_id, TaskId()) - ): - if isinstance(chunk, TokenChunk): - full_response_1 += chunk.text - - # Run both stream responses simultaneously - _ = await asyncio.gather(collect_response_0(), collect_response_1()) - - print(f"full_response_0: {full_response_0}") - print(f"full_response_1: {full_response_1}") - - # Case-insensitive check for Paris in both responses - assert "paris" in full_response_0.lower(), ( - f"Expected 'Paris' in response, but got: {full_response_0}" - ) - assert "paris" in full_response_1.lower(), ( - f"Expected 'Paris' in response, but got: {full_response_1}" - ) - - finally: - await supervisor_0.astop() - await supervisor_1.astop() - - -@pytest.mark.asyncio -async def test_supervisor_early_stopping( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - instance_id = InstanceId() - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - task = chat_completion_task(instance_id, TaskId()) - - max_tokens = 50 - assert isinstance(task, ChatCompletionTask) - print(f"chat_completion_task.task_params: {task.task_params}") - assert isinstance(task.task_params, ChatCompletionTaskParams) - task_params: ChatCompletionTaskParams = task.task_params - - try: - task_params.max_tokens = max_tokens - # Convert messages to a list to allow indexing, then update the first message's content - messages = list(task_params.messages) - messages[0].content = "Please count from 1 to 100" - task_params.messages = messages - - full_response = "" - count = 0 - stop_reason: FinishReason | None = None - - async for chunk in supervisor.stream_response(task=task): - if isinstance(chunk, TokenChunk): - full_response += chunk.text - count += 1 - if chunk.finish_reason: - stop_reason = chunk.finish_reason - - print(f"full_response: {full_response}") - - assert count == max_tokens + 1 - assert "7" in full_response.lower() - assert "99" not in full_response.lower() - - assert stop_reason == "length" - - finally: - await supervisor.astop() - - -@pytest.mark.asyncio -async def test_supervisor_handles_terminated_runner( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], -): - """Test that the supervisor handles a terminated runner""" - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - # Terminate the runner - supervisor.runner_process.terminate() - await asyncio.sleep(0.1) - - assert not supervisor.runner_process.is_alive() - - del supervisor - - -@pytest.mark.asyncio -async def test_supervisor_handles_killed_runner( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], -): - """Test that the supervisor 
handles a killed runner""" - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - assert supervisor.runner_process.is_alive() - - # Forcibly kill the runner - supervisor.runner_process.kill() - await asyncio.sleep(0.1) - - assert not supervisor.runner_process.is_alive() - - del supervisor diff --git a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py b/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py deleted file mode 100644 index 11d24f2b..00000000 --- a/src/exo/worker/tests/test_supervisor/test_supervisor_sad.py +++ /dev/null @@ -1,92 +0,0 @@ -import asyncio -from typing import Callable - -import pytest - -from exo.shared.types.common import Host -from exo.shared.types.tasks import Task, TaskId -from exo.shared.types.worker.common import InstanceId, RunnerError -from exo.shared.types.worker.shards import PipelineShardMetadata -from exo.worker.runner.runner_supervisor import RunnerSupervisor -from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID - - -@pytest.mark.asyncio -async def test_supervisor_instantiation_exception( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - model_shard_meta.immediate_exception = True - - # _ = await RunnerSupervisor.create( - # model_shard_meta=model_shard_meta, - # hosts=hosts(1, offset=10), - # ) - - with pytest.raises(RunnerError): - _ = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - -@pytest.mark.asyncio -async def test_supervisor_instantiation_timeout( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - model_shard_meta.should_timeout = 10 # timeout after 10s - - with pytest.raises(asyncio.TimeoutError): - _ = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - -@pytest.mark.asyncio -async def test_supervisor_inference_exception( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST FAIL" - with pytest.raises(RunnerError): - async for _ in supervisor.stream_response(task): - pass - - -@pytest.mark.asyncio -async def test_supervisor_inference_timeout( - pipeline_shard_meta: Callable[..., PipelineShardMetadata], - hosts: Callable[..., list[Host]], - chat_completion_task: Callable[[InstanceId, TaskId], Task], -): - """Test that asking for the capital of France returns 'Paris' in the response""" - model_shard_meta = pipeline_shard_meta(1, 0) - - supervisor = await RunnerSupervisor.create( - model_shard_meta=model_shard_meta, - hosts=hosts(1, offset=10), - ) - - task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID) - task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT" - 
with pytest.raises(asyncio.TimeoutError): - async for _ in supervisor.stream_response(task): - pass - - await asyncio.sleep(0.1) diff --git a/src/exo/worker/tests/worker_management.py b/src/exo/worker/tests/worker_management.py deleted file mode 100644 index 220665e6..00000000 --- a/src/exo/worker/tests/worker_management.py +++ /dev/null @@ -1,189 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -from anyio import fail_after - -from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent -from exo.shared.types.chunks import TokenChunk -from exo.shared.types.common import NodeId, SessionId -from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated -from exo.shared.types.tasks import TaskId, TaskStatus -from exo.utils.channels import Receiver, Sender, channel -from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader -from exo.worker.main import Worker -from exo.worker.tests.constants import MASTER_NODE_ID - -session = SessionId(master_node_id=MASTER_NODE_ID, election_clock=0) - - -@dataclass -class WorkerMailbox: - sender: Sender[ForwarderEvent] - receiver: Receiver[ForwarderEvent] - counter: int = 0 - - async def append_events( - self, - events: list[Event], - *, - origin: NodeId, - ): - for event in events: - await self.sender.send( - ForwarderEvent( - origin=origin, - session=session, - event=event, - origin_idx=self.counter, - ) - ) - self.counter += 1 - - async def receive(self) -> ForwarderEvent: - return await self.receiver.receive() - - def collect(self) -> list[ForwarderEvent]: - # Clear out the test mailboxes currently held events - return self.receiver.collect() - - -def create_worker_void_mailbox( - node_id: NodeId, shard_downloader: ShardDownloader | None = None -) -> Worker: - if shard_downloader is None: - shard_downloader = NoopShardDownloader() - return Worker( - node_id, - session_id=session, - shard_downloader=shard_downloader, - initial_connection_messages=[], - connection_message_receiver=channel[ConnectionMessage]()[1], - global_event_receiver=channel[ForwarderEvent]()[1], - local_event_sender=channel[ForwarderEvent]()[0], - command_sender=channel[ForwarderCommand]()[0], - ) - - -def create_worker_and_mailbox( - node_id: NodeId, shard_downloader: ShardDownloader | None = None -) -> tuple[Worker, WorkerMailbox]: - if shard_downloader is None: - shard_downloader = NoopShardDownloader() - - lsend, receiver = channel[ForwarderEvent]() - sender, grecv = channel[ForwarderEvent]() - worker = Worker( - node_id, - session_id=session, - shard_downloader=shard_downloader, - initial_connection_messages=[], - connection_message_receiver=channel[ConnectionMessage]()[1], - global_event_receiver=grecv, - local_event_sender=lsend, - command_sender=channel[ForwarderCommand]()[0], - ) - return worker, WorkerMailbox(sender, receiver) - - -def create_worker_with_old_mailbox( - node_id: NodeId, - mailbox: WorkerMailbox, - shard_downloader: ShardDownloader | None = None, -) -> Worker: - if shard_downloader is None: - shard_downloader = NoopShardDownloader() - # This function is subtly complex, come talk to Evan if you want to know what it's actually doing. 
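# A hedged sketch of what that complexity buys (driver code assumed, not part of
# this diff): the construction just below wires the new Worker to clones of the
# old mailbox's channel ends, so events buffered before a restart are still
# delivered to the replacement Worker afterwards.
#
#     from exo.shared.types.common import NodeId
#
#     node_id = NodeId()  # assumes NodeId() is constructible no-arg, as in the tests
#     worker_a, mailbox = create_worker_and_mailbox(node_id)
#     ...  # worker_a is torn down
#     worker_b = create_worker_with_old_mailbox(node_id, mailbox)
#     # worker_b now drains the same queues worker_a was using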
- worker = Worker( - node_id, - session_id=session, - shard_downloader=shard_downloader, - initial_connection_messages=[], - connection_message_receiver=channel[ConnectionMessage]()[1], - global_event_receiver=mailbox.sender.clone_receiver(), - local_event_sender=mailbox.receiver.clone_sender(), - command_sender=channel[ForwarderCommand]()[0], - ) - return worker - - -async def read_streaming_response( - global_event_receiver: WorkerMailbox, filter_task: TaskId | None = None -) -> tuple[bool, bool, str, int]: - # Read off all events - these should be our GenerationChunk events - seen_task_started = 0 - seen_task_finished = 0 - response_string = "" - finish_reason: str | None = None - token_count = 0 - extra_events: list[Event] = [] - - event = (await global_event_receiver.receive()).event - extra_events.append(event) - - from loguru import logger - - logger.info("STARTING READ") - - with fail_after(10.0): - if filter_task: - while not ( - isinstance(event, TaskStateUpdated) - and event.task_status == TaskStatus.Running - and event.task_id == filter_task - ): - event = (await global_event_receiver.receive()).event - extra_events.append(event) - - for event in extra_events: - if isinstance(event, TaskStateUpdated): - if event.task_status == TaskStatus.Running: - seen_task_started += 1 - if event.task_status == TaskStatus.Complete: - seen_task_finished += 1 - if isinstance(event, ChunkGenerated) and isinstance( - event.chunk, TokenChunk - ): - response_string += event.chunk.text - token_count += 1 - if event.chunk.finish_reason: - finish_reason = event.chunk.finish_reason - - while not seen_task_finished: - event = (await global_event_receiver.receive()).event - if isinstance(event, TaskStateUpdated): - if event.task_status == TaskStatus.Running: - seen_task_started += 1 - if event.task_status == TaskStatus.Complete: - seen_task_finished += 1 - if isinstance(event, ChunkGenerated) and isinstance( - event.chunk, TokenChunk - ): - response_string += event.chunk.text - token_count += 1 - if event.chunk.finish_reason: - finish_reason = event.chunk.finish_reason - - logger.info(f"finish reason {finish_reason}") - - return seen_task_started == 1, seen_task_finished == 1, response_string, token_count - - -async def until_event_with_timeout[T]( - global_event_receiver: WorkerMailbox, - event_type: type[T], - multiplicity: int = 1, - condition: Callable[[T], bool] = lambda x: True, - timeout: float = 30.0, -) -> None: - times_seen = 0 - - with fail_after(timeout): - while times_seen < multiplicity: - event = (await global_event_receiver.receive()).event - if isinstance(event, event_type): - print(f"Wow! We got a {event}") - print( - f"But condition? {condition(event) if isinstance(event, event_type) else False}" - ) - if event and isinstance(event, event_type) and condition(event): - times_seen += 1 diff --git a/src/exo/worker/utils/macmon.py b/src/exo/worker/utils/macmon.py new file mode 100644 index 00000000..3e4e29e1 --- /dev/null +++ b/src/exo/worker/utils/macmon.py @@ -0,0 +1,97 @@ +import platform +import shutil +from subprocess import CalledProcessError + +from anyio import run_process +from pydantic import BaseModel, ConfigDict, ValidationError + + +class MacMonError(Exception): + """Exception raised for errors in the MacMon functions.""" + + +def _get_binary_path() -> str: + """ + Get the path to the macmon binary. + + Raises: + MacMonError: If the binary doesn't exist or can't be made executable. 
+ """ + # Check for macOS with ARM chip + system = platform.system().lower() + machine = platform.machine().lower() + + if system != "darwin" or not ( + "arm" in machine or "m1" in machine or "m2" in machine + ): + raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips") + + path = shutil.which("macmon") + + if path is None: + raise MacMonError("MacMon not found in PATH") + + return path + + +class TempMetrics(BaseModel): + """Temperature-related metrics returned by macmon.""" + + cpu_temp_avg: float + gpu_temp_avg: float + + model_config = ConfigDict(extra="ignore") + + +class Metrics(BaseModel): + """Complete set of metrics returned by macmon. + + Unknown fields are ignored for forward-compatibility. + """ + + all_power: float + ane_power: float + cpu_power: float + ecpu_usage: tuple[int, float] + gpu_power: float + gpu_ram_power: float + gpu_usage: tuple[int, float] + pcpu_usage: tuple[int, float] + ram_power: float + sys_power: float + temp: TempMetrics + timestamp: str + + model_config = ConfigDict(extra="ignore") + + +async def get_metrics_async() -> Metrics: + """ + Asynchronously run the binary and return the metrics as a Python dictionary. + + Args: + binary_path: Optional path to the binary. If not provided, will use the bundled binary. + + Returns: + A mapping containing system metrics. + + Raises: + MacMonError: If there's an error running the binary. + """ + path = _get_binary_path() + + result = None + try: + # TODO: Keep Macmon running in the background? + result = await run_process([path, "pipe", "-s", "1"]) + + return Metrics.model_validate_json(result.stdout.decode().strip()) + + except ValidationError as e: + raise MacMonError(f"Error parsing JSON output: {e}") from e + except CalledProcessError as e: + if result: + raise MacMonError( + f"MacMon failed with return code {result.returncode}" + ) from e + raise e diff --git a/src/exo/worker/utils/macmon/.DS_Store b/src/exo/worker/utils/macmon/.DS_Store deleted file mode 100644 index a3585876..00000000 Binary files a/src/exo/worker/utils/macmon/.DS_Store and /dev/null differ diff --git a/src/exo/worker/utils/macmon/__init__.py b/src/exo/worker/utils/macmon/__init__.py deleted file mode 100644 index bf4bda58..00000000 --- a/src/exo/worker/utils/macmon/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .macmon import MacMonError, get_metrics, get_metrics_async - -__all__ = ["get_metrics", "get_metrics_async", "MacMonError"] diff --git a/src/exo/worker/utils/macmon/macmon.py b/src/exo/worker/utils/macmon/macmon.py deleted file mode 100644 index 81e949ff..00000000 --- a/src/exo/worker/utils/macmon/macmon.py +++ /dev/null @@ -1,150 +0,0 @@ -import asyncio -import platform -import shutil -import subprocess - -from pydantic import BaseModel, ConfigDict, ValidationError - - -class MacMonError(Exception): - """Exception raised for errors in the MacMon functions.""" - - -def _get_binary_path() -> str: - """ - Get the path to the macmon binary. - - Raises: - MacMonError: If the binary doesn't exist or can't be made executable. 
- """ - # Check for macOS with ARM chip - system = platform.system().lower() - machine = platform.machine().lower() - - if system != "darwin" or not ( - "arm" in machine or "m1" in machine or "m2" in machine - ): - raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips") - - path = shutil.which("macmon") - - if path is None: - raise MacMonError("MacMon not found in PATH") - - return path - - -# --------------------------------------------------------------------------- -# Pydantic metric structures -# --------------------------------------------------------------------------- - - -class MemoryMetrics(BaseModel): - """Memory-related metrics returned by macmon.""" - - ram_total: int | None = None - ram_usage: int | None = None - swap_total: int | None = None - swap_usage: int | None = None - - model_config = ConfigDict(extra="ignore") - - -class TempMetrics(BaseModel): - """Temperature-related metrics returned by macmon.""" - - cpu_temp_avg: float | None = None - gpu_temp_avg: float | None = None - - model_config = ConfigDict(extra="ignore") - - -class Metrics(BaseModel): - """Complete set of metrics returned by *macmon* binary. - - All fields are optional to allow for partial output from the binary. - Unknown fields are ignored for forward-compatibility. - """ - - all_power: float | None = None - ane_power: float | None = None - cpu_power: float | None = None - ecpu_usage: tuple[int, float] | None = None - gpu_power: float | None = None - gpu_ram_power: float | None = None - gpu_usage: tuple[int, float] | None = None - memory: MemoryMetrics | None = None - pcpu_usage: tuple[int, float] | None = None - ram_power: float | None = None - sys_power: float | None = None - temp: TempMetrics | None = None - timestamp: str | None = None - - model_config = ConfigDict(extra="ignore") - - -# --------------------------------------------------------------------------- -# Synchronous helper -# --------------------------------------------------------------------------- - - -def get_metrics() -> Metrics: - """ - Run the binary and return the metrics as a Python dictionary. - - Returns: - A mapping containing system metrics. - - Raises: - MacMonError: If there's an error running the binary. - """ - path = _get_binary_path() - - try: - # Run the binary with the argument -s 1 and capture its output - result = subprocess.run( - [path, "pipe", "-s", "1"], capture_output=True, text=True, check=True - ) - - return Metrics.model_validate_json(result.stdout) - - except subprocess.CalledProcessError as e: - raise MacMonError(f"Error running binary: {e.stderr}") from e # type: ignore - except ValidationError as e: - raise MacMonError(f"Error parsing JSON output: {e}") from e - - -async def get_metrics_async() -> Metrics: - """ - Asynchronously run the binary and return the metrics as a Python dictionary. - - Args: - binary_path: Optional path to the binary. If not provided, will use the bundled binary. - - Returns: - A mapping containing system metrics. - - Raises: - MacMonError: If there's an error running the binary. 
- """ - path = _get_binary_path() - - try: - proc = await asyncio.create_subprocess_exec( - path, - "pipe", - "-s", - "1", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdout, stderr = await proc.communicate() - - if proc.returncode != 0: - raise MacMonError(f"Error running binary: {stderr.decode().strip()}") - - return Metrics.model_validate_json(stdout.decode().strip()) - - except ValidationError as e: - raise MacMonError(f"Error parsing JSON output: {e}") from e diff --git a/src/exo/worker/utils/profile.py b/src/exo/worker/utils/profile.py index 134aa600..9506428b 100644 --- a/src/exo/worker/utils/profile.py +++ b/src/exo/worker/utils/profile.py @@ -4,7 +4,6 @@ import platform from typing import Any, Callable, Coroutine import anyio -import psutil from loguru import logger from exo.shared.types.memory import Memory @@ -13,59 +12,37 @@ from exo.shared.types.profiling import ( NodePerformanceProfile, SystemPerformanceProfile, ) -from exo.worker.utils.macmon.macmon import ( +from exo.worker.utils.macmon import ( + MacMonError, Metrics, ) -from exo.worker.utils.macmon.macmon import ( +from exo.worker.utils.macmon import ( get_metrics_async as macmon_get_metrics_async, ) from exo.worker.utils.system_info import ( - get_mac_friendly_name_async, - get_mac_system_info_async, - get_network_interface_info_async, + get_friendly_name, + get_model_and_chip, + get_network_interfaces, ) -async def get_metrics_async() -> Metrics: - """Return detailed Metrics on macOS or a minimal fallback elsewhere. - - The *Metrics* schema comes from ``utils.macmon.macmon``; on non-macOS systems we - fill only the ``memory`` sub-structure so downstream code can still access - ``metrics.memory.ram_total`` & ``ram_usage``. - """ +async def get_metrics_async() -> Metrics | None: + """Return detailed Metrics on macOS or a minimal fallback elsewhere.""" if platform.system().lower() == "darwin": return await macmon_get_metrics_async() - return Metrics() -async def get_memory_profile_async() -> MemoryPerformanceProfile: - """Return MemoryPerformanceProfile using psutil (fast, cross-platform). +def get_memory_profile() -> MemoryPerformanceProfile: + """Construct a MemoryPerformanceProfile using psutil""" + override_memory_env = os.getenv("OVERRIDE_MEMORY_MB") + override_memory: int | None = ( + Memory.from_mb(int(override_memory_env)).in_bytes + if override_memory_env + else None + ) - Uses synchronous psutil calls in a worker thread to avoid blocking the event loop. 
- """ - - def _read_psutil() -> MemoryPerformanceProfile: - vm = psutil.virtual_memory() - sm = psutil.swap_memory() - - override_memory_env = os.getenv("OVERRIDE_MEMORY_MB") - override_memory: int | None = ( - Memory.from_mb(int(override_memory_env)).in_bytes - if override_memory_env - else None - ) - - return MemoryPerformanceProfile.from_bytes( - ram_total=int(vm.total), - ram_available=int(override_memory) - if override_memory - else int(vm.available), - swap_total=int(sm.total), - swap_available=int(sm.free), - ) - - return await asyncio.to_thread(_read_psutil) + return MemoryPerformanceProfile.from_psutil(override_memory=override_memory) async def start_polling_memory_metrics( @@ -81,9 +58,9 @@ async def start_polling_memory_metrics( """ while True: try: - mem = await get_memory_profile_async() + mem = get_memory_profile() await callback(mem) - except Exception as e: + except MacMonError as e: logger.opt(exception=e).error("Memory Monitor encountered error") finally: await anyio.sleep(poll_interval_s) @@ -95,61 +72,41 @@ async def start_polling_node_metrics( poll_interval_s = 1.0 while True: try: - # Gather metrics & system info with a timeout on each call metrics = await get_metrics_async() + if metrics is None: + return - ( - system_info, - network_interfaces, - mac_friendly_name, - ) = await asyncio.gather( - get_mac_system_info_async(), - get_network_interface_info_async(), - get_mac_friendly_name_async(), - ) + network_interfaces = get_network_interfaces() + # these awaits could be joined but realistically they should be cached + model_id, chip_id = await get_model_and_chip() + friendly_name = await get_friendly_name() # do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop - memory_profile = await get_memory_profile_async() + memory_profile = get_memory_profile() await callback( NodePerformanceProfile( - model_id=system_info.model_id, - chip_id=system_info.chip_id, - friendly_name=mac_friendly_name or "Unknown", + model_id=model_id, + chip_id=chip_id, + friendly_name=friendly_name, network_interfaces=network_interfaces, memory=memory_profile, system=SystemPerformanceProfile( - flops_fp16=0, - gpu_usage=metrics.gpu_usage[1] - if metrics.gpu_usage is not None - else 0, - temp=metrics.temp.gpu_temp_avg - if metrics.temp is not None - and metrics.temp.gpu_temp_avg is not None - else 0, - sys_power=metrics.sys_power - if metrics.sys_power is not None - else 0, - pcpu_usage=metrics.pcpu_usage[1] - if metrics.pcpu_usage is not None - else 0, - ecpu_usage=metrics.ecpu_usage[1] - if metrics.ecpu_usage is not None - else 0, - ane_power=metrics.ane_power - if metrics.ane_power is not None - else 0, + gpu_usage=metrics.gpu_usage[1], + temp=metrics.temp.gpu_temp_avg, + sys_power=metrics.sys_power, + pcpu_usage=metrics.pcpu_usage[1], + ecpu_usage=metrics.ecpu_usage[1], + ane_power=metrics.ane_power, ), ) ) except asyncio.TimeoutError: - # One of the operations took too long; skip this iteration but keep the loop alive. logger.warning( "[resource_monitor] Operation timed out after 30s, skipping this cycle." ) - except Exception as e: - # Catch-all to ensure the monitor keeps running. 
+        except MacMonError as e:
             logger.opt(exception=e).error("Resource Monitor encountered error")
         finally:
             await anyio.sleep(poll_interval_s)
diff --git a/src/exo/worker/utils/system_info.py b/src/exo/worker/utils/system_info.py
index d9873df2..930d9428 100644
--- a/src/exo/worker/utils/system_info.py
+++ b/src/exo/worker/utils/system_info.py
@@ -1,77 +1,34 @@
-import asyncio
-import re
+import socket
 import sys
+from subprocess import CalledProcessError
 
-from loguru import logger
-from pydantic import BaseModel, Field
+import psutil
+from anyio import run_process
 
 from exo.shared.types.profiling import NetworkInterfaceInfo
 
 
-class SystemInfo(BaseModel):
-    model_id: str
-    chip_id: str
-    memory: int
-    network_interfaces: list[NetworkInterfaceInfo] = Field(default_factory=list)
-
-
-async def get_mac_friendly_name_async() -> str | None:
+async def get_friendly_name() -> str:
     """
     Asynchronously gets the 'Computer Name' (friendly name) of a Mac,
     e.g. "John's MacBook Pro".
     Falls back to the hostname when not on macOS or if an error occurs.
     """
+    hostname = socket.gethostname()
+
+    # TODO: better non-Mac support
     if sys.platform != "darwin":  # 'darwin' is the platform name for macOS
-        logger.warning("Mac friendly name is designed for macOS only.")
-        return None
+        return hostname
 
     try:
-        # asyncio.create_subprocess_exec allows running external commands asynchronously.
-        # stdout=asyncio.subprocess.PIPE captures standard output.
-        # stderr=asyncio.subprocess.PIPE captures standard error.
-        process = await asyncio.create_subprocess_exec(
-            "scutil",
-            "--get",
-            "ComputerName",
-            stdout=asyncio.subprocess.PIPE,
-            stderr=asyncio.subprocess.PIPE,
-        )
+        process = await run_process(["scutil", "--get", "ComputerName"])
+    except CalledProcessError:
+        return hostname
 
-        # process.communicate() reads all data from stdout and stderr
-        # and waits for the process to terminate.
-        # It returns a tuple (stdout_data, stderr_data).
-        stdout_data, stderr_data = await process.communicate()
-
-        # Check the return code of the process
-        if process.returncode == 0:
-            if stdout_data:
-                # Decode from bytes to string and strip whitespace
-                friendly_name = stdout_data.decode().strip()
-                return friendly_name
-            else:
-                # Should not happen if returncode is 0, but good to check
-                print("scutil command succeeded but produced no output.")
-                return None
-        else:
-            # If there was an error, print the stderr output
-            error_message = (
-                stderr_data.decode().strip() if stderr_data else "Unknown error"
-            )
-            print(
-                f"Error executing scutil (return code {process.returncode}): {error_message}"
-            )
-            return None
-
-    except FileNotFoundError:
-        # This would happen if scutil is somehow not found, highly unlikely on a Mac.
-        print("Error: 'scutil' command not found. Are you sure this is macOS?")
-        return None
-    except Exception as e:
-        print(f"An unexpected error occurred: {e}")
-        return None
+    return process.stdout.decode("utf-8", errors="replace").strip() or hostname
 
 
-async def get_network_interface_info_async() -> list[NetworkInterfaceInfo]:
+def get_network_interfaces() -> list[NetworkInterfaceInfo]:
     """
     Retrieves network interface information.
    Enumerates each interface's IPv4/IPv6 addresses via psutil.net_if_addrs()
@@ -80,162 +37,47 @@ async def get_network_interface_info_async() -> list[NetworkInterfaceInfo]:
     """
     interfaces_info: list[NetworkInterfaceInfo] = []
 
-    async def _run_cmd_async(command_parts: list[str]) -> str | None:
-        # Helper to run a command and return its stdout, or None on error.
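# For context on the psutil-based replacement that follows (a hedged sketch,
# not the module's code): psutil.net_if_addrs() already maps interface names to
# address entries, so the deleted ifconfig/regex parsing below reduces to a loop.
#
#     import socket
#     import psutil
#
#     for iface, addrs in psutil.net_if_addrs().items():
#         for addr in addrs:
#             if addr.family in (socket.AF_INET, socket.AF_INET6):
#                 print(iface, addr.address)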
- try: - process = await asyncio.create_subprocess_exec( - *command_parts, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout_data, stderr_data = await process.communicate() - if process.returncode == 0: - # Use 'utf-8' and replace errors for robustness - return stdout_data.decode("utf-8", errors="replace").strip() - else: - error_message = ( - stderr_data.decode("utf-8", errors="replace").strip() - if stderr_data - else "Unknown error" - ) - print( - f"Error executing {' '.join(command_parts)} (code {process.returncode}): {error_message}" - ) - return None - except FileNotFoundError: - print( - f"Error: Command '{command_parts[0]}' not found. Ensure it's in PATH." - ) - return None - except Exception as e: - print( - f"An unexpected error occurred running {' '.join(command_parts)}: {e}" - ) - return None - - # Get interface names and IP addresses from ifconfig - ifconfig_output = await _run_cmd_async(["ifconfig"]) - if ifconfig_output: - # Regex for interface name (e.g., en0:, utun0:, tailscale0.) - interface_header_pattern = re.compile(r"^([a-zA-Z0-9\._-]+):") - # Regex for IPv4 address (inet) - inet_pattern = re.compile(r"^\s+inet\s+(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})") - # Regex for IPv6 address (inet6) - inet6_pattern = re.compile(r"^\s+inet6\s+([0-9a-fA-F:]+(?:%[a-zA-Z0-9._-]+)?)") - - current_if_name: str | None = None - for line in ifconfig_output.splitlines(): - header_match = interface_header_pattern.match(line) - if header_match: - current_if_name = header_match.group(1) - - if current_if_name: - inet_m = inet_pattern.match(line) - if inet_m: - ipv4_address = inet_m.group(1) + for iface, services in psutil.net_if_addrs().items(): + for service in services: + match service.family: + case socket.AF_INET | socket.AF_INET6: interfaces_info.append( - NetworkInterfaceInfo( - name=current_if_name, ip_address=ipv4_address, type="" - ) - ) - - inet6_m = inet6_pattern.match(line) - if inet6_m: - ipv6_address = inet6_m.group(1) - interfaces_info.append( - NetworkInterfaceInfo( - name=current_if_name, ip_address=ipv6_address, type="" - ) + NetworkInterfaceInfo(name=iface, ip_address=service.address) ) + case _: + pass return interfaces_info -async def get_mac_system_info_async() -> SystemInfo: +async def get_model_and_chip() -> tuple[str, str]: """Get Mac system information using system_profiler.""" - model_id_val = "Unknown Model" - chip_id_val = "Unknown Chip" - memory_val = 0 - network_interfaces_info_list: list[NetworkInterfaceInfo] = [] + model = "Unknown Model" + chip = "Unknown Chip" + # TODO: better non mac support if sys.platform != "darwin": - return SystemInfo( - model_id=model_id_val, - chip_id=chip_id_val, - memory=memory_val, - network_interfaces=network_interfaces_info_list, - ) + return (model, chip) try: - process = await asyncio.create_subprocess_exec( - "system_profiler", - "SPHardwareDataType", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + process = await run_process( + [ + "system_profiler", + "SPHardwareDataType", + ] ) - stdout_data, stderr_data = await process.communicate() - if process.returncode == 0: - if stdout_data: - output = stdout_data.decode().strip() - model_line = next( - (line for line in output.split("\n") if "Model Name" in line), None - ) - model_id_val = ( - model_line.split(": ")[1] if model_line else "Unknown Model" - ) + except CalledProcessError: + return (model, chip) - chip_line = next( - (line for line in output.split("\n") if "Chip" in line), None - ) - chip_id_val = chip_line.split(": 
")[1] if chip_line else "Unknown Chip" + # less interested in errors here because this value should be hard coded + output = process.stdout.decode().strip() - memory_line = next( - (line for line in output.split("\n") if "Memory" in line), None - ) - memory_str = ( - memory_line.split(": ")[1] if memory_line else "0 GB" - ) # Default to "0 GB" - memory_units = memory_str.split() - if len(memory_units) == 2: - try: - memory_value_int = int(memory_units[0]) - if memory_units[1] == "GB": - memory_val = memory_value_int * 1024 # Assuming MB - elif memory_units[1] == "MB": - memory_val = memory_value_int - else: # TB? Unlikely for typical memory, handle gracefully - memory_val = memory_value_int # Store as is, let consumer decide unit or log - print(f"Warning: Unknown memory unit {memory_units[1]}") - except ValueError: - print( - f"Warning: Could not parse memory value {memory_units[0]}" - ) - memory_val = 0 - - else: - print( - "system_profiler command succeeded but produced no output for hardware." - ) - else: - error_message = ( - stderr_data.decode().strip() if stderr_data else "Unknown error" - ) - print( - f"Error executing system_profiler (return code {process.returncode}): {error_message}" - ) - except Exception as e: - print(f"Error getting Mac hardware info: {e}") - - # Call the new function to get network info - try: - network_interfaces_info_list = await get_network_interface_info_async() - except Exception as e: - print(f"Error getting Mac network interface info: {e}") - network_interfaces_info_list = [] - - return SystemInfo( - model_id=model_id_val, - chip_id=chip_id_val, - memory=memory_val, - network_interfaces=network_interfaces_info_list, + model_line = next( + (line for line in output.split("\n") if "Model Name" in line), None ) + model = model_line.split(": ")[1] if model_line else "Unknown Model" + + chip_line = next((line for line in output.split("\n") if "Chip" in line), None) + chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip" + + return (model, chip) diff --git a/uv.lock b/uv.lock index a3e25d6f..861f4649 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 1 +revision = 3 requires-python = ">=3.13" resolution-markers = [ "sys_platform == 'darwin'", @@ -320,15 +320,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, -] - [[package]] name = "exo" version = "0.3.0" @@ -351,7 +342,6 @@ dependencies = [ { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = 
"networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pathlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -372,6 +362,7 @@ dependencies = [ dev = [ { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pytest-env", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "ruff", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] @@ -394,7 +385,6 @@ requires-dist = [ { name = "mlx", specifier = ">=0.29.3" }, { name = "mlx-lm", specifier = ">=0.28.3" }, { name = "networkx", specifier = ">=3.5" }, - { name = "openai", specifier = ">=1.99.9" }, { name = "pathlib", specifier = ">=1.0.1" }, { name = "protobuf", specifier = ">=6.32.0" }, { name = "psutil", specifier = ">=7.0.0" }, @@ -415,6 +405,7 @@ requires-dist = [ dev = [ { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, + { name = "pytest-env" }, { name = "ruff", specifier = ">=0.11.13" }, ] @@ -594,34 +585,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, ] -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = 
"sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - [[package]] name = "huggingface-hub" version = "0.36.0" @@ -671,46 +634,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] -[[package]] -name = "jiter" -version = "0.11.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a3/68/0357982493a7b20925aece061f7fb7a2678e3b232f8d73a6edb7e5304443/jiter-0.11.1.tar.gz", hash = "sha256:849dcfc76481c0ea0099391235b7ca97d7279e0fa4c86005457ac7c88e8b76dc", size = 168385, upload-time = "2025-10-17T11:31:15.186Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/4b/e4dd3c76424fad02a601d570f4f2a8438daea47ba081201a721a903d3f4c/jiter-0.11.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:71b6a920a5550f057d49d0e8bcc60945a8da998019e83f01adf110e226267663", size = 305272, upload-time = "2025-10-17T11:29:39.249Z" }, - { url = "https://files.pythonhosted.org/packages/67/83/2cd3ad5364191130f4de80eacc907f693723beaab11a46c7d155b07a092c/jiter-0.11.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b3de72e925388453a5171be83379549300db01284f04d2a6f244d1d8de36f94", size = 314038, upload-time = "2025-10-17T11:29:40.563Z" }, - { url = "https://files.pythonhosted.org/packages/d3/3c/8e67d9ba524e97d2f04c8f406f8769a23205026b13b0938d16646d6e2d3e/jiter-0.11.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc19dd65a2bd3d9c044c5b4ebf657ca1e6003a97c0fc10f555aa4f7fb9821c00", size = 345977, upload-time = "2025-10-17T11:29:42.009Z" }, - { url = "https://files.pythonhosted.org/packages/8d/a5/489ce64d992c29bccbffabb13961bbb0435e890d7f2d266d1f3df5e917d2/jiter-0.11.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d58faaa936743cd1464540562f60b7ce4fd927e695e8bc31b3da5b914baa9abd", size = 364503, upload-time = "2025-10-17T11:29:43.459Z" }, - { url = "https://files.pythonhosted.org/packages/d4/c0/e321dd83ee231d05c8fe4b1a12caf1f0e8c7a949bf4724d58397104f10f2/jiter-0.11.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:902640c3103625317291cb73773413b4d71847cdf9383ba65528745ff89f1d14", size = 487092, upload-time = "2025-10-17T11:29:44.835Z" }, - { url = "https://files.pythonhosted.org/packages/f9/5e/8f24ec49c8d37bd37f34ec0112e0b1a3b4b5a7b456c8efff1df5e189ad43/jiter-0.11.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30405f726e4c2ed487b176c09f8b877a957f535d60c1bf194abb8dadedb5836f", size = 376328, upload-time = "2025-10-17T11:29:46.175Z" }, - { url = "https://files.pythonhosted.org/packages/7f/70/ded107620e809327cf7050727e17ccfa79d6385a771b7fe38fb31318ef00/jiter-0.11.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3217f61728b0baadd2551844870f65219ac4a1285d5e1a4abddff3d51fdabe96", size = 356632, upload-time = "2025-10-17T11:29:47.454Z" }, - { url = "https://files.pythonhosted.org/packages/19/53/c26f7251613f6a9079275ee43c89b8a973a95ff27532c421abc2a87afb04/jiter-0.11.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b1364cc90c03a8196f35f396f84029f12abe925415049204446db86598c8b72c", size = 384358, upload-time = "2025-10-17T11:29:49.377Z" }, - { url = 
"https://files.pythonhosted.org/packages/84/16/e0f2cc61e9c4d0b62f6c1bd9b9781d878a427656f88293e2a5335fa8ff07/jiter-0.11.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53a54bf8e873820ab186b2dca9f6c3303f00d65ae5e7b7d6bda1b95aa472d646", size = 517279, upload-time = "2025-10-17T11:29:50.968Z" }, - { url = "https://files.pythonhosted.org/packages/60/5c/4cd095eaee68961bca3081acbe7c89e12ae24a5dae5fd5d2a13e01ed2542/jiter-0.11.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7e29aca023627b0e0c2392d4248f6414d566ff3974fa08ff2ac8dbb96dfee92a", size = 508276, upload-time = "2025-10-17T11:29:52.619Z" }, - { url = "https://files.pythonhosted.org/packages/65/9b/4a57922437ca8753ef823f434c2dec5028b237d84fa320f06a3ba1aec6e8/jiter-0.11.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d892b184da4d94d94ddb4031296931c74ec8b325513a541ebfd6dfb9ae89904b", size = 313814, upload-time = "2025-10-17T11:29:58.509Z" }, - { url = "https://files.pythonhosted.org/packages/76/50/62a0683dadca25490a4bedc6a88d59de9af2a3406dd5a576009a73a1d392/jiter-0.11.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa22c223a3041dacb2fcd37c70dfd648b44662b4a48e242592f95bda5ab09d58", size = 344987, upload-time = "2025-10-17T11:30:00.208Z" }, - { url = "https://files.pythonhosted.org/packages/da/00/2355dbfcbf6cdeaddfdca18287f0f38ae49446bb6378e4a5971e9356fc8a/jiter-0.11.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330e8e6a11ad4980cd66a0f4a3e0e2e0f646c911ce047014f984841924729789", size = 356399, upload-time = "2025-10-17T11:30:02.084Z" }, - { url = "https://files.pythonhosted.org/packages/8d/00/d6006d069e7b076e4c66af90656b63da9481954f290d5eca8c715f4bf125/jiter-0.11.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:0fa1f70da7a8a9713ff8e5f75ec3f90c0c870be6d526aa95e7c906f6a1c8c676", size = 304624, upload-time = "2025-10-17T11:30:06.678Z" }, - { url = "https://files.pythonhosted.org/packages/fc/45/4a0e31eb996b9ccfddbae4d3017b46f358a599ccf2e19fbffa5e531bd304/jiter-0.11.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:569ee559e5046a42feb6828c55307cf20fe43308e3ae0d8e9e4f8d8634d99944", size = 315042, upload-time = "2025-10-17T11:30:08.87Z" }, - { url = "https://files.pythonhosted.org/packages/e7/91/22f5746f5159a28c76acdc0778801f3c1181799aab196dbea2d29e064968/jiter-0.11.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f69955fa1d92e81987f092b233f0be49d4c937da107b7f7dcf56306f1d3fcce9", size = 346357, upload-time = "2025-10-17T11:30:10.222Z" }, - { url = "https://files.pythonhosted.org/packages/f5/4f/57620857d4e1dc75c8ff4856c90cb6c135e61bff9b4ebfb5dc86814e82d7/jiter-0.11.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:090f4c9d4a825e0fcbd0a2647c9a88a0f366b75654d982d95a9590745ff0c48d", size = 365057, upload-time = "2025-10-17T11:30:11.585Z" }, - { url = "https://files.pythonhosted.org/packages/ce/34/caf7f9cc8ae0a5bb25a5440cc76c7452d264d1b36701b90fdadd28fe08ec/jiter-0.11.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbf3d8cedf9e9d825233e0dcac28ff15c47b7c5512fdfe2e25fd5bbb6e6b0cee", size = 487086, upload-time = "2025-10-17T11:30:13.052Z" }, - { url = "https://files.pythonhosted.org/packages/50/17/85b5857c329d533d433fedf98804ebec696004a1f88cabad202b2ddc55cf/jiter-0.11.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2aa9b1958f9c30d3d1a558b75f0626733c60eb9b7774a86b34d88060be1e67fe", size = 376083, upload-time = "2025-10-17T11:30:14.416Z" }, - { url = 
"https://files.pythonhosted.org/packages/85/d3/2d9f973f828226e6faebdef034097a2918077ea776fb4d88489949024787/jiter-0.11.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e42d1ca16590b768c5e7d723055acd2633908baacb3628dd430842e2e035aa90", size = 357825, upload-time = "2025-10-17T11:30:15.765Z" }, - { url = "https://files.pythonhosted.org/packages/f4/55/848d4dabf2c2c236a05468c315c2cb9dc736c5915e65449ccecdba22fb6f/jiter-0.11.1-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5db4c2486a023820b701a17aec9c5a6173c5ba4393f26662f032f2de9c848b0f", size = 383933, upload-time = "2025-10-17T11:30:17.34Z" }, - { url = "https://files.pythonhosted.org/packages/0b/6c/204c95a4fbb0e26dfa7776c8ef4a878d0c0b215868011cc904bf44f707e2/jiter-0.11.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:4573b78777ccfac954859a6eff45cbd9d281d80c8af049d0f1a3d9fc323d5c3a", size = 517118, upload-time = "2025-10-17T11:30:18.684Z" }, - { url = "https://files.pythonhosted.org/packages/88/25/09956644ea5a2b1e7a2a0f665cb69a973b28f4621fa61fc0c0f06ff40a31/jiter-0.11.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:7593ac6f40831d7961cb67633c39b9fef6689a211d7919e958f45710504f52d3", size = 508194, upload-time = "2025-10-17T11:30:20.719Z" }, - { url = "https://files.pythonhosted.org/packages/d5/fa/3b05e5c9d32efc770a8510eeb0b071c42ae93a5b576fd91cee9af91689a1/jiter-0.11.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2cc5a3965285ddc33e0cab933e96b640bc9ba5940cea27ebbbf6695e72d6511c", size = 312561, upload-time = "2025-10-17T11:30:26.742Z" }, - { url = "https://files.pythonhosted.org/packages/50/d3/335822eb216154ddb79a130cbdce88fdf5c3e2b43dc5dba1fd95c485aaf5/jiter-0.11.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b572b3636a784c2768b2342f36a23078c8d3aa6d8a30745398b1bab58a6f1a8", size = 344551, upload-time = "2025-10-17T11:30:28.252Z" }, - { url = "https://files.pythonhosted.org/packages/31/6d/a0bed13676b1398f9b3ba61f32569f20a3ff270291161100956a577b2dd3/jiter-0.11.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ad93e3d67a981f96596d65d2298fe8d1aa649deb5374a2fb6a434410ee11915e", size = 363051, upload-time = "2025-10-17T11:30:30.009Z" }, - { url = "https://files.pythonhosted.org/packages/a4/03/313eda04aa08545a5a04ed5876e52f49ab76a4d98e54578896ca3e16313e/jiter-0.11.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a83097ce379e202dcc3fe3fc71a16d523d1ee9192c8e4e854158f96b3efe3f2f", size = 485897, upload-time = "2025-10-17T11:30:31.429Z" }, - { url = "https://files.pythonhosted.org/packages/5f/13/a1011b9d325e40b53b1b96a17c010b8646013417f3902f97a86325b19299/jiter-0.11.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7042c51e7fbeca65631eb0c332f90c0c082eab04334e7ccc28a8588e8e2804d9", size = 375224, upload-time = "2025-10-17T11:30:33.18Z" }, - { url = "https://files.pythonhosted.org/packages/92/da/1b45026b19dd39b419e917165ff0ea629dbb95f374a3a13d2df95e40a6ac/jiter-0.11.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a68d679c0e47649a61df591660507608adc2652442de7ec8276538ac46abe08", size = 356606, upload-time = "2025-10-17T11:30:34.572Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0c/9acb0e54d6a8ba59ce923a180ebe824b4e00e80e56cefde86cc8e0a948be/jiter-0.11.1-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a1b0da75dbf4b6ec0b3c9e604d1ee8beaf15bc046fff7180f7d89e3cdbd3bb51", size = 384003, upload-time = 
"2025-10-17T11:30:35.987Z" }, - { url = "https://files.pythonhosted.org/packages/3f/2b/e5a5fe09d6da2145e4eed651e2ce37f3c0cf8016e48b1d302e21fb1628b7/jiter-0.11.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:69dd514bf0fa31c62147d6002e5ca2b3e7ef5894f5ac6f0a19752385f4e89437", size = 516946, upload-time = "2025-10-17T11:30:37.425Z" }, - { url = "https://files.pythonhosted.org/packages/5f/fe/db936e16e0228d48eb81f9934e8327e9fde5185e84f02174fcd22a01be87/jiter-0.11.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:bb31ac0b339efa24c0ca606febd8b77ef11c58d09af1b5f2be4c99e907b11111", size = 507614, upload-time = "2025-10-17T11:30:38.977Z" }, -] - [[package]] name = "linkify-it-py" version = "2.0.3" @@ -969,25 +892,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" }, ] -[[package]] -name = "openai" -version = "2.7.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "jiter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/84/2c/3ca91dbd1a5b80c20fbd1e21d601f6afd7fd51927a1b27b08226b67ebd61/openai-2.7.0.tar.gz", hash = "sha256:8c42c24d06afece19e69afcb6c2b23b8b90f603a81616d8a0be80b80fb527ed2", size = 595876, upload-time = "2025-11-03T23:52:07.935Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/0f/e9618a92a9497846a3071f2a7ed43409215947106c7e5ce7d082f784de10/openai-2.7.0-py3-none-any.whl", hash = "sha256:9fc44861a692b7e80a7ec1252c10af79612a3ef1581ecb192caf4585afca5363", size = 1008759, upload-time = "2025-11-03T23:52:05.322Z" }, -] - [[package]] name = "packaging" version = "25.0" @@ -1213,6 +1117,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, ] +[[package]] +name = "pytest-env" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/12/9c87d0ca45d5992473208bcef2828169fa7d39b8d7fc6e3401f5c08b8bf7/pytest_env-1.2.0.tar.gz", hash = "sha256:475e2ebe8626cee01f491f304a74b12137742397d6c784ea4bc258f069232b80", size = 8973, upload-time = "2025-10-09T19:15:47.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/98/822b924a4a3eb58aacba84444c7439fce32680592f394de26af9c76e2569/pytest_env-1.2.0-py3-none-any.whl", hash = 
"sha256:d7e5b7198f9b83c795377c09feefa45d56083834e60d04767efd64819fc9da00", size = 6251, upload-time = "2025-10-09T19:15:46.077Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -1468,32 +1384,32 @@ dependencies = [ { name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806 } +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802 }, - { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995 }, - { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948 }, - { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986 }, - { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222 }, - { url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097 }, - { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309 }, - { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712 }, - { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725 }, - { url = 
"https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875 }, - { url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451 }, - { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794 }, - { url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188 }, - { url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978 }, - { url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271 }, - { url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216 }, - { url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860 }, - { url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567 }, - { url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473 }, - { url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855 }, - { url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022 }, - { url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736 }, - { url = 
"https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908 }, - { url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706 }, + { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, + { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = 
"sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, + { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" }, + { url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271, upload-time = "2025-10-06T20:22:22.06Z" }, + { url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" }, + { url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860, upload-time = "2025-10-06T20:22:24.602Z" }, + { url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" }, + { url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" }, + { url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" }, + { url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 
1129022, upload-time = "2025-10-06T20:22:29.981Z" }, + { url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" }, + { url = "https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908, upload-time = "2025-10-06T20:22:32.073Z" }, + { url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" }, ] [[package]]