Consolidate cleanup

commit b45cbdeecd (parent 28a91787e8)
Author: rltakashige
Committed by: GitHub
Date: 2025-11-21 14:54:02 +00:00
72 changed files with 634 additions and 4854 deletions

View File

@@ -4,9 +4,8 @@
 <inspection_tool class="PyCompatibilityInspection" enabled="true" level="WARNING" enabled_by_default="true">
   <option name="ourVersions">
     <value>
-      <list size="2">
-        <item index="0" class="java.lang.String" itemvalue="2.7" />
-        <item index="1" class="java.lang.String" itemvalue="3.14" />
+      <list size="1">
+        <item index="0" class="java.lang.String" itemvalue="3.14" />
       </list>
     </value>
   </option>

View File

@@ -2,14 +2,24 @@
 This type stub file was generated by pyright.
 """
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Protocol, Literal, Self
 import mlx.nn as nn
 from mlx.core import array
+import mlx.core as mx
+
+class Cache(Protocol):
+    keys: mx.array
+    values: mx.array
+    def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...
+    @property
+    def state(self) -> tuple[mx.array, mx.array]: ...
+    @state.setter
+    def state(self, v) -> None: ...
+
 def make_prompt_cache(
     model: nn.Module, max_kv_size: Optional[int] = ...
-) -> List[KVCache | Any]:
+) -> List[Cache | Any]:
     """
     Construct the model's cache for use in generation.
@@ -24,7 +34,7 @@ def make_prompt_cache(
""" """
def save_prompt_cache( def save_prompt_cache(
file_name: str, cache: List[Any], metadata: Dict[str, str] = ... file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
) -> None: ) -> None:
""" """
Save a pre-computed prompt cache to a file. Save a pre-computed prompt cache to a file.
@@ -50,12 +60,12 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
     the metadata if requested.
     """
-def can_trim_prompt_cache(cache: List[Any]) -> bool:
+def can_trim_prompt_cache(cache: List[Cache]) -> bool:
     """
     Check if model's cache can be trimmed.
     """
-def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
+def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
     """
     Trim the model's cache by the given number of tokens.
@@ -72,27 +82,22 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
 def create_attention_mask(
     N: int, offset: int, return_array: bool, window_size: Optional[int]
-): # -> array | Literal['causal'] | None:
-    ...
+) -> array | Literal["causal"] | None: ...
 
-class _BaseCache:
+class _BaseCache(Cache):
+    keys: mx.array
+    values: mx.array
     @property
-    def state(self): # -> list[Any]:
-        ...
+    def state(self) -> tuple[mx.array, mx.array]: ...
     @state.setter
-    def state(self, v): # -> None:
-        ...
+    def state(self, v) -> None: ...
     @property
-    def meta_state(self): # -> Literal['']:
-        ...
+    def meta_state(self) -> Literal[""]: ...
     @meta_state.setter
-    def meta_state(self, v): # -> None:
-        ...
-    def is_trimmable(self): # -> Literal[False]:
-        ...
+    def meta_state(self, v) -> None: ...
+    def is_trimmable(self) -> Literal[False]: ...
     @classmethod
-    def from_state(cls, state, meta_state): # -> Self:
-        ...
+    def from_state(cls, state, meta_state) -> Self: ...
 
 class ConcatenateKVCache(_BaseCache):
     """ConcatenateKVCache the simplest KV cache implementation.

TODO.md
View File

@@ -1,6 +1,5 @@
-1. Currently EXO just doesn't start cleanly a lot of the time. I see two kinds of issues:
+b. EXO starts but then after creating an instance that instance never loads (either gets stuck in Loading of Inactive).
 2. Currently a lot of requests from the API are timing out, but we still process those requests internally. If an API request times out, we should cancel all corresponding tasks to that API request (why process a request with nobody listening).
-3. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
 4. I'd like to see profiled network latency / bandwidth.
 5. I'd like to see how much bandwidth each link is using.
 6. We should handle the case where one machine doesn't have the model downloaded and then other machines are waiting on it. In this case we get loads of timeout errors because the others are waiting for the one that needs to download the model.
@@ -14,41 +13,13 @@
 16. Dynamically switch to higher priority connection when it becomes available. Probably bring back InstanceReplacedAtomically.
 17. Faster model loads by streaming model from other devices in cluster.
 18. Add support for specifying the type of network connection to use in a test. Depends on 15/16.
-19. Fix mx.distributed.Group typing.
 20. Add chat completion cancellations (e.g OpenWebUI has something for cancelling an ongoing request).
-21. Make two separate things: tensor or pipeline, and ring or ibv.
-22. When downloading for the first time, stuff times out and I think the model never ends up actually loading into memory, or something.
 23. Do we need cache_limit? We went back and forth on that a lot because we thought it might be causing issues. One problem is it sets it relative to model size. So if you have multiple models loaded in it will take the most recent model size for the cache_limit. This is problematic if you launch DeepSeek -> Llama for example.
-24. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
-25. Rethink retry logic
+24. further openai/lmstudio api compatibility
 
 Potential refactors:
-1. Make ForwarderEvent typed
 2. Topology can be simplified
-3. Get rid of InstanceReplacedAtomically
 
 Random errors we've run into:
-1. exo.shared.types.worker.common.RunnerError: RuntimeError: [ibv] Couldn't connect (error: 60). Traceback: Traceback (most recent call last):
-  File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/worker/runner/runner.py", line 54, in main
-    model, tokenizer, sampler, group = await loop.run_in_executor(
-                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
-    ...<8 lines>...
-    )
-    ^
-  File "/nix/store/s7ik6dazn4nd2jdg9l36qf5q0z18sjyk-python3-3.13.8/lib/python3.13/concurrent/futures/thread.py", line 59, in run
-    result = self.fn(*self.args, **self.kwargs)
-  File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/engines/mlx/utils_mlx.py", line 149, in initialize_mlx
-    group = mlx_distributed_init(
-        model_shard_meta.device_rank,
-        ...<4 lines>...
-        or (mlx_ibv_devices is not None and len(mlx_ibv_devices) > 1),
-    )
-  File "/Users/puffin4/actions-runner/_work/exo/exo/src/exo/engines/mlx/utils_mlx.py", line 124, in mlx_distributed_init
-    group = mx.distributed.init(
-        backend="ring" if hosts is not None else "ibv",
-        strict=strict,
-    )
-RuntimeError: [ibv] Couldn't connect (error: 60)
-2.

View File

@@ -17,7 +17,6 @@ dependencies = [
"filelock>=3.18.0", "filelock>=3.18.0",
"aiosqlite>=0.21.0", "aiosqlite>=0.21.0",
"networkx>=3.5", "networkx>=3.5",
"openai>=1.99.9",
"pathlib>=1.0.1", "pathlib>=1.0.1",
"protobuf>=6.32.0", "protobuf>=6.32.0",
"rich>=14.1.0", "rich>=14.1.0",
@@ -49,6 +48,7 @@ exo = "exo.main:main"
 dev = [
     "pytest>=8.4.0",
     "pytest-asyncio>=1.0.0",
+    "pytest-env",
     "ruff>=0.11.13",
 ]
@@ -131,4 +131,7 @@ asyncio_mode = "auto"
 markers = [
     "slow: marks tests as slow (deselected by default)"
 ]
+env = [
+    "EXO_TESTS=1"
+]
 addopts = "-m 'not slow'"

View File

@@ -1,17 +0,0 @@
-# TODO: Do we want so many constants?
-KV_GROUP_SIZE = 32
-KV_BITS = None
-ATTENTION_KV_BITS = 4
-MAX_TOKENS = 8192
-MAX_KV_SIZE = 3200
-KEEP_KV_SIZE = 1600
-QUANTIZE_MODEL_MODE = "affine"
-CACHE_GROUP_SIZE = 64
-KV_CACHE_BITS = 8
-TEMPERATURE = 1.0
-
-# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
-TRUST_REMOTE_CODE = True
-
-# TODO: Do we really want this?
-HIDE_THINKING = False

View File

@@ -23,9 +23,7 @@ from exo.worker.download.impl_shard_downloader import exo_shard_downloader
 from exo.worker.main import Worker
 
-# TODO: Entrypoint refactor
 # I marked this as a dataclass as I want trivial constructors.
-# This is the collection of systems for our entire application.
 @dataclass
 class Node:
     router: Router

View File

@@ -14,7 +14,6 @@ from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from loguru import logger
 
-from exo.engines.mlx.constants import HIDE_THINKING
 from exo.shared.apply import apply
 from exo.shared.election import ElectionMessage
 from exo.shared.models.model_cards import MODEL_CARDS
@@ -49,6 +48,7 @@ from exo.shared.types.worker.instances import Instance, InstanceId
 from exo.utils.banner import print_startup_banner
 from exo.utils.channels import Receiver, Sender
 from exo.utils.event_buffer import OrderedBuffer
+from exo.worker.engines.mlx.constants import HIDE_THINKING
 
 def chunk_to_response(

View File

@@ -240,8 +240,6 @@ def _find_connection_ip(
         if (
             connection.local_node_id == node_i.node_id
             and connection.send_back_node_id == node_j.node_id
-            # TODO: Check if we need this.
-            and connection.send_back_multiaddr is not None
         ):
             yield connection.send_back_multiaddr.ip_address

View File

@@ -30,7 +30,7 @@ def create_node():
             swap_available=1000,
         ),
         network_interfaces=[],
-        system=SystemPerformanceProfile(flops_fp16=1000),
+        system=SystemPerformanceProfile(),
     ),
 )

View File

@@ -28,9 +28,14 @@ from exo.shared.types.profiling import (
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
-from exo.shared.types.tasks import ChatCompletionTask, TaskStatus
-from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
-from exo.shared.types.worker.shards import PipelineShardMetadata
+from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
+from exo.shared.types.tasks import TaskStatus
+from exo.shared.types.worker.instances import (
+    InstanceMeta,
+    MlxRingInstance,
+    ShardAssignments,
+)
+from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
 from exo.utils.channels import channel
@@ -91,7 +96,7 @@ async def test_master():
                     swap_available=Memory.from_bytes(0),
                 ),
                 network_interfaces=[],
-                system=SystemPerformanceProfile(flops_fp16=0),
+                system=SystemPerformanceProfile(),
             ),
         )
     ),
@@ -118,7 +123,8 @@ async def test_master():
                 n_layers=16,
                 storage_size=Memory.from_bytes(678948),
             ),
-            strategy="auto",
+            sharding=Sharding.Pipeline,
+            instance_meta=InstanceMeta.MlxRing,
         )
     ),
 )
@@ -160,9 +166,8 @@ async def test_master():
     )[0]
     assert events[1].event == InstanceCreated(
         event_id=events[1].event.event_id,
-        instance=Instance(
+        instance=MlxRingInstance(
             instance_id=events[1].event.instance.instance_id,
-            instance_type=InstanceStatus.Active,
             shard_assignments=ShardAssignments(
                 model_id=ModelId("llama-3.2-1b"),
                 runner_to_shard={
@@ -186,22 +191,13 @@ async def test_master():
         ),
     )
     assert isinstance(events[2].event, TaskCreated)
-    assert events[2].event == TaskCreated(
-        event_id=events[2].event.event_id,
-        task_id=events[2].event.task_id,
-        task=ChatCompletionTask(
-            task_id=events[2].event.task_id,
-            command_id=events[2].event.task.command_id,
-            instance_id=events[2].event.task.instance_id,
-            task_status=TaskStatus.Pending,
-            task_params=ChatCompletionTaskParams(
-                model="llama-3.2-1b",
-                messages=[
-                    ChatCompletionMessage(
-                        role="user", content="Hello, how are you?"
-                    )
-                ],
-            ),
-        ),
+    assert events[2].event.task.task_status == TaskStatus.Pending
+    assert isinstance(events[2].event.task, ChatCompletionTask)
+    assert events[2].event.task.task_params == ChatCompletionTaskParams(
+        model="llama-3.2-1b",
+        messages=[
+            ChatCompletionMessage(role="user", content="Hello, how are you?")
+        ],
     )
 
     await master.shutdown()

View File

@@ -1,7 +1,6 @@
 from typing import Callable
 
 import pytest
-from loguru import logger
 
 from exo.master.placement import (
     get_instance_placements_after_create,
@@ -15,9 +14,15 @@ from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
 from exo.shared.types.topology import Connection, NodeInfo
-from exo.shared.types.worker.common import InstanceId
-from exo.shared.types.worker.instances import Instance, InstanceStatus
+from exo.shared.types.worker.instances import (
+    Instance,
+    InstanceId,
+    InstanceMeta,
+    MlxIbvInstance,
+    MlxRingInstance,
+)
 from exo.shared.types.worker.runners import ShardAssignments
+from exo.shared.types.worker.shards import Sharding
 
 
 @pytest.fixture
@@ -27,9 +32,8 @@ def topology() -> Topology:
 @pytest.fixture
 def instance() -> Instance:
-    return Instance(
+    return MlxRingInstance(
         instance_id=InstanceId(),
-        instance_type=InstanceStatus.Active,
         shard_assignments=ShardAssignments(
             model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
         ),
@@ -51,7 +55,8 @@ def create_instance_command(model_meta: ModelMetadata) -> CreateInstance:
     return CreateInstance(
         command_id=CommandId(),
         model_meta=model_meta,
-        strategy="auto",
+        sharding=Sharding.Pipeline,
+        instance_meta=InstanceMeta.MlxRing,
     )
@@ -78,11 +83,7 @@ def test_get_instance_placements_create_instance(
         available_memory
     ) # make it exactly fit across all nodes
-    create_instance_command = CreateInstance(
-        command_id=CommandId(),
-        model_meta=model_meta,
-        strategy="auto",
-    )
+    cic = create_instance_command(model_meta)
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
@@ -94,9 +95,7 @@ def test_get_instance_placements_create_instance(
     topology.add_connection(create_connection(node_id_c, node_id_a))
 
     # act
-    placements = get_instance_placements_after_create(
-        create_instance_command, topology, {}
-    )
+    placements = get_instance_placements_after_create(cic, topology, {})
 
     # assert
     assert len(placements) == 1
@@ -128,19 +127,15 @@ def test_get_instance_placements_one_node_exact_fit(
     topology = Topology()
     node_id = NodeId()
     topology.add_node(create_node(1000 * 1024, node_id))
-    create_instance_command = CreateInstance(
-        command_id=CommandId(),
-        model_meta=ModelMetadata(
+    cic = create_instance_command(
+        ModelMetadata(
             model_id=ModelId("test-model"),
             storage_size=Memory.from_kb(1000),
             pretty_name="Test Model",
             n_layers=10,
         ),
-        strategy="auto",
-    )
-    placements = get_instance_placements_after_create(
-        create_instance_command, topology, {}
     )
+    placements = get_instance_placements_after_create(cic, topology, {})
 
     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -157,19 +152,15 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
     topology = Topology()
     node_id = NodeId()
     topology.add_node(create_node(1001 * 1024, node_id))
-    create_instance_command = CreateInstance(
-        command_id=CommandId(),
-        model_meta=ModelMetadata(
+    cic = create_instance_command(
+        ModelMetadata(
             model_id=ModelId("test-model"),
             storage_size=Memory.from_kb(1000),
             pretty_name="Test Model",
             n_layers=10,
         ),
-        strategy="auto",
-    )
-    placements = get_instance_placements_after_create(
-        create_instance_command, topology, {}
     )
+    placements = get_instance_placements_after_create(cic, topology, {})
 
     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
@@ -186,19 +177,17 @@ def test_get_instance_placements_one_node_not_fit(
     topology = Topology()
     node_id = NodeId()
     topology.add_node(create_node(1000 * 1024, node_id))
-    create_instance_command = CreateInstance(
-        command_id=CommandId(),
+    cic = create_instance_command(
         model_meta=ModelMetadata(
             model_id=ModelId("test-model"),
             storage_size=Memory.from_kb(1001),
             pretty_name="Test Model",
             n_layers=10,
         ),
-        strategy="auto",
     )
 
     with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
-        get_instance_placements_after_create(create_instance_command, topology, {})
+        get_instance_placements_after_create(cic, topology, {})
 
 
 def test_get_transition_events_no_change(instance: Instance):
@@ -301,16 +290,12 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
     topology.add_connection(create_connection(node_id_e, node_id_y))
     topology.add_connection(create_connection(node_id_f, node_id_z))
-    create_instance_command = CreateInstance(
-        command_id=CommandId(),
+    cic = create_instance_command(
         model_meta=model_meta,
-        strategy="auto",
     )
 
     # Act
-    placements = get_instance_placements_after_create(
-        create_instance_command, topology, {}
-    )
+    placements = get_instance_placements_after_create(cic, topology, {})
 
     # Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
     # D-E-F has more total memory.
@@ -346,7 +331,6 @@ def test_tensor_rdma_backend_connectivity_matrix(
     ethernet_interface = NetworkInterfaceInfo(
         name="en0",
         ip_address="192.168.1.100",
-        type="ethernet",
     )
 
     assert node_a.node_profile is not None
@@ -377,13 +361,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
             network_interfaces=[
                 NetworkInterfaceInfo(
                     name="en3",
-                    ip_address=conn_c_a.send_back_multiaddr.ip_address,
-                    type="rdma",
-                ),
-                NetworkInterfaceInfo(
-                    name="en4",
-                    ip_address=conn_b_a.send_back_multiaddr.ip_address,
-                    type="rdma",
+                    ip_address=conn_a_b.send_back_multiaddr.ip_address,
                 ),
                 ethernet_interface,
             ],
@@ -395,15 +373,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
             friendly_name="test",
             memory=node_b.node_profile.memory,
             network_interfaces=[
-                NetworkInterfaceInfo(
-                    name="en3",
-                    ip_address=conn_c_b.send_back_multiaddr.ip_address,
-                    type="rdma",
-                ),
                 NetworkInterfaceInfo(
                     name="en4",
-                    ip_address=conn_a_b.send_back_multiaddr.ip_address,
-                    type="rdma",
+                    ip_address=conn_b_c.send_back_multiaddr.ip_address,
                 ),
                 ethernet_interface,
             ],
@@ -416,14 +388,8 @@ def test_tensor_rdma_backend_connectivity_matrix(
             memory=node_c.node_profile.memory,
             network_interfaces=[
                 NetworkInterfaceInfo(
-                    name="en3",
-                    ip_address=conn_a_c.send_back_multiaddr.ip_address,
-                    type="rdma",
-                ),
-                NetworkInterfaceInfo(
-                    name="en4",
-                    ip_address=conn_b_c.send_back_multiaddr.ip_address,
-                    type="rdma",
+                    name="en5",
+                    ip_address=conn_c_a.send_back_multiaddr.ip_address,
                 ),
                 ethernet_interface,
             ],
@@ -436,29 +402,26 @@ def test_tensor_rdma_backend_connectivity_matrix(
     topology.add_connection(conn_a_b)
     topology.add_connection(conn_b_c)
     topology.add_connection(conn_c_a)
-    topology.add_connection(conn_b_a)
-    topology.add_connection(conn_c_b)
-    topology.add_connection(conn_a_c)
 
-    create_instance_command = CreateInstance(
+    cic = CreateInstance(
+        sharding=Sharding.Tensor,
+        instance_meta=InstanceMeta.MlxIbv,
         command_id=CommandId(),
         model_meta=model_meta,
-        strategy="tensor_rdma",
     )
-    placements = get_instance_placements_after_create(
-        create_instance_command, topology, {}
-    )
+    placements = get_instance_placements_after_create(cic, topology, {})
 
     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
     instance = placements[instance_id]
-    assert instance.hosts is None
-    assert instance.mlx_ibv_devices is not None
-    assert instance.mlx_ibv_coordinator is not None
-    matrix = instance.mlx_ibv_devices
+    assert isinstance(instance, MlxIbvInstance)
+    assert instance.ibv_devices is not None
+    assert instance.ibv_coordinator is not None
+    matrix = instance.ibv_devices
 
     assert len(matrix) == 3
     for i in range(3):
@@ -471,11 +434,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
     idx_b = node_to_idx[node_id_b]
     idx_c = node_to_idx[node_id_c]
 
-    logger.info(matrix)
-
-    assert matrix[idx_a][idx_b] == "rdma_en4"
-    assert matrix[idx_b][idx_c] == "rdma_en3"
-    assert matrix[idx_c][idx_a] == "rdma_en3"
-    assert ":" in instance.mlx_ibv_coordinator
-    assert not instance.mlx_ibv_coordinator.startswith("169.254")
+    assert matrix[idx_a][idx_b] == "rdma_en3"
+    assert matrix[idx_b][idx_c] == "rdma_en4"
+    assert matrix[idx_c][idx_a] == "rdma_en5"
+
+    assert ":" in instance.ibv_coordinator
+    assert not instance.ibv_coordinator.startswith("169.254")

View File

@@ -13,6 +13,7 @@ from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.topology import Connection, NodeInfo
+from exo.shared.types.worker.shards import Sharding
 
 
 @pytest.fixture
@@ -200,7 +201,9 @@ def test_get_shard_assignments(
     selected_cycle = cycles[0]
 
     # act
-    shard_assignments = get_shard_assignments(model_meta, selected_cycle, "pipeline")
+    shard_assignments = get_shard_assignments(
+        model_meta, selected_cycle, Sharding.Pipeline
+    )
 
     # assert
     runner_id_a = shard_assignments.node_to_runner[node_a_id]

View File

@@ -32,7 +32,7 @@ def node_profile() -> NodePerformanceProfile:
     memory_profile = MemoryPerformanceProfile.from_bytes(
         ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
     )
-    system_profile = SystemPerformanceProfile(flops_fp16=1000)
+    system_profile = SystemPerformanceProfile()
     return NodePerformanceProfile(
         model_id="test",
         chip_id="test",
@@ -99,7 +99,7 @@ def test_update_node_profile(
             ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
         ),
         network_interfaces=[],
-        system=SystemPerformanceProfile(flops_fp16=1000),
+        system=SystemPerformanceProfile(),
     )
 
     # act

View File

@@ -10,6 +10,7 @@ from exo.shared.types.events import (
     IndexedEvent,
     InstanceCreated,
     InstanceDeleted,
+    NodeCreated,
     NodeDownloadProgress,
     NodeMemoryMeasured,
     NodePerformanceMeasured,
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
     TestEvent,
     TopologyEdgeCreated,
     TopologyEdgeDeleted,
-    TopologyNodeCreated,
 )
 from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
 from exo.shared.types.state import State
@@ -41,14 +41,14 @@ def event_apply(event: Event, state: State) -> State:
             TestEvent() | ChunkGenerated() | TaskAcknowledged()
         ): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
             return state
-        case NodeDownloadProgress():
-            return apply_node_download_progress(event, state)
         case InstanceCreated():
             return apply_instance_created(event, state)
         case InstanceDeleted():
             return apply_instance_deleted(event, state)
         case NodePerformanceMeasured():
             return apply_node_performance_measured(event, state)
+        case NodeDownloadProgress():
+            return apply_node_download_progress(event, state)
         case NodeMemoryMeasured():
             return apply_node_memory_measured(event, state)
         case RunnerDeleted():
@@ -63,7 +63,7 @@ def event_apply(event: Event, state: State) -> State:
             return apply_task_failed(event, state)
         case TaskStatusUpdated():
             return apply_task_status_updated(event, state)
-        case TopologyNodeCreated():
+        case NodeCreated():
             return apply_topology_node_created(event, state)
         case TopologyEdgeCreated():
             return apply_topology_edge_created(event, state)
@@ -173,7 +173,6 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
return state.model_copy(update={"runners": new_runners}) return state.model_copy(update={"runners": new_runners})
# TODO: This whole function needs fixing
def apply_node_performance_measured( def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State event: NodePerformanceMeasured, state: State
) -> State: ) -> State:
@@ -183,8 +182,8 @@ def apply_node_performance_measured(
     }
     state = state.model_copy(update={"node_profiles": new_profiles})
     topology = copy.copy(state.topology)
+    # TODO: NodeCreated
     if not topology.contains_node(event.node_id):
-        # TODO: figure out why this is happening in the first place
         topology.add_node(NodeInfo(node_id=event.node_id))
     topology.update_node_profile(event.node_id, event.node_profile)
     return state.model_copy(update={"topology": topology})
@@ -202,7 +201,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State
             memory=event.memory,
             network_interfaces=[],
             system=SystemPerformanceProfile(
-                flops_fp16=0.0,
+                # TODO: flops_fp16=0.0,
                 gpu_usage=0.0,
                 temp=0.0,
                 sys_power=0.0,
@@ -217,6 +216,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State
         }
         if not topology.contains_node(event.node_id):
             topology.add_node(NodeInfo(node_id=event.node_id))
+            # TODO: NodeCreated
         topology.update_node_profile(event.node_id, created)
         return state.model_copy(
             update={"node_profiles": created_profiles, "topology": topology}
@@ -227,6 +227,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State
         **state.node_profiles,
         event.node_id: updated,
     }
+    # TODO: NodeCreated
     if not topology.contains_node(event.node_id):
         topology.add_node(NodeInfo(node_id=event.node_id))
     topology.update_node_profile(event.node_id, updated)
@@ -235,7 +236,7 @@ def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State
     )
 
-def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
+def apply_topology_node_created(event: NodeCreated, state: State) -> State:
     topology = copy.copy(state.topology)
     topology.add_node(NodeInfo(node_id=event.node_id))
     return state.model_copy(update={"topology": topology})

View File

@@ -1,23 +0,0 @@
-from typing import TYPE_CHECKING, Literal, TypeAlias, get_type_hints
-
-if TYPE_CHECKING:
-    import openai.types as openai_types
-    import openai.types.chat as openai_chat
-
-    types = openai_types
-    chat = openai_chat
-else:
-    types = None
-    chat = None
-
-FinishReason: TypeAlias = Literal[
-    "stop", "length", "tool_calls", "content_filter", "function_call"
-]
-
-if TYPE_CHECKING:
-    assert (
-        get_type_hints(chat.chat_completion_chunk.Choice)["finish_reason"]
-        == FinishReason
-    ), "Upstream changed Choice.finish_reason; update FinishReason alias."
-
-__all__ = ["types", "chat", "FinishReason"]

View File

@@ -3,12 +3,15 @@ from typing import Any, Literal
 from pydantic import BaseModel, Field
 
-from exo.shared.openai_compat import FinishReason
 from exo.shared.types.common import CommandId
 from exo.shared.types.models import ModelMetadata
 from exo.shared.types.worker.instances import InstanceId, InstanceMeta
 from exo.shared.types.worker.shards import Sharding
 
+FinishReason = Literal[
+    "stop", "length", "tool_calls", "content_filter", "function_call"
+]
+
 class ModelListModel(BaseModel):
     id: str

View File

@@ -1,9 +1,10 @@
 from enum import Enum
 
-from exo.shared.openai_compat import FinishReason
-from exo.shared.types.models import ModelId
 from exo.utils.pydantic_ext import TaggedModel
+
+from .api import FinishReason
+from .models import ModelId
 
 class ChunkType(str, Enum):
     Token = "Token"

View File

@@ -8,7 +8,6 @@ from exo.shared.types.worker.shards import Sharding
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
 
-# TODO: We need to have a distinction between create instance and spin up instance.
 class BaseCommand(TaggedModel):
     command_id: CommandId = Field(default_factory=CommandId)

View File

@@ -70,6 +70,11 @@ class RunnerDeleted(BaseEvent):
     runner_id: RunnerId
 
+# TODO
+class NodeCreated(BaseEvent):
+    node_id: NodeId
+
 class NodePerformanceMeasured(BaseEvent):
     node_id: NodeId
     node_profile: NodePerformanceProfile
@@ -89,10 +94,6 @@ class ChunkGenerated(BaseEvent):
     chunk: GenerationChunk
 
-class TopologyNodeCreated(BaseEvent):
-    node_id: NodeId
-
 class TopologyEdgeCreated(BaseEvent):
     edge: Connection
@@ -116,7 +117,7 @@ Event = (
     | NodeMemoryMeasured
     | NodeDownloadProgress
     | ChunkGenerated
-    | TopologyNodeCreated
+    | NodeCreated
     | TopologyEdgeCreated
     | TopologyEdgeDeleted
 )

View File

@@ -1,5 +1,7 @@
 from typing import Self
 
+import psutil
+
 from exo.shared.types.memory import Memory
 from exo.utils.pydantic_ext import CamelCaseModel
@@ -21,9 +23,21 @@ class MemoryPerformanceProfile(CamelCaseModel):
             swap_available=Memory.from_bytes(swap_available),
         )
 
+    @classmethod
+    def from_psutil(cls, *, override_memory: int | None) -> Self:
+        vm = psutil.virtual_memory()
+        sm = psutil.swap_memory()
+        return cls.from_bytes(
+            ram_total=vm.total,
+            ram_available=vm.available if override_memory is None else override_memory,
+            swap_total=sm.total,
+            swap_available=sm.free,
+        )
+
 class SystemPerformanceProfile(CamelCaseModel):
-    flops_fp16: float
+    # TODO: flops_fp16: float
     gpu_usage: float = 0.0
     temp: float = 0.0
@@ -36,7 +50,6 @@ class SystemPerformanceProfile(CamelCaseModel):
 class NetworkInterfaceInfo(CamelCaseModel):
     name: str
     ip_address: str
-    type: str
 
 class NodePerformanceProfile(CamelCaseModel):

View File

@@ -1,45 +0,0 @@
-from exo.shared.openai_compat import FinishReason
-from exo.utils.pydantic_ext import TaggedModel
-
-
-class BaseRunnerResponse(TaggedModel):
-    pass
-
-
-class InitializedResponse(BaseRunnerResponse):
-    time_taken: float
-
-
-class TokenizedResponse(BaseRunnerResponse):
-    prompt_tokens: int
-
-
-class GenerationResponse(BaseRunnerResponse):
-    text: str
-    token: int
-    # logprobs: list[float] | None = None # too big. we can change to be top-k
-    finish_reason: FinishReason | None = None
-
-
-class PrintResponse(BaseRunnerResponse):
-    text: str
-
-
-class FinishedResponse(BaseRunnerResponse):
-    pass
-
-
-class ErrorResponse(BaseRunnerResponse):
-    error_type: str
-    error_message: str
-    traceback: str
-
-
-RunnerResponse = (
-    InitializedResponse
-    | TokenizedResponse
-    | GenerationResponse
-    | PrintResponse
-    | FinishedResponse
-    | ErrorResponse
-)

View File

@@ -1 +0,0 @@

View File

@@ -1,34 +0,0 @@
-from exo.shared.types.tasks import Task
-from exo.shared.types.worker.instances import BoundInstance, Instance
-from exo.shared.types.worker.runners import RunnerId
-from exo.utils.pydantic_ext import TaggedModel
-
-
-class BaseRunnerOp(TaggedModel):
-    runner_id: RunnerId
-
-
-class AssignRunnerOp(BaseRunnerOp):
-    instance: Instance
-
-    def bound_instance(self) -> BoundInstance:
-        return BoundInstance(instance=self.instance, bound_runner_id=self.runner_id)
-
-
-class UnassignRunnerOp(BaseRunnerOp):
-    pass
-
-
-class RunnerUpOp(BaseRunnerOp):
-    pass
-
-
-class RunnerDownOp(BaseRunnerOp):
-    pass
-
-
-class ExecuteTaskOp(BaseRunnerOp):
-    task: Task
-
-
-RunnerOp = AssignRunnerOp | ExecuteTaskOp | UnassignRunnerOp | RunnerUpOp | RunnerDownOp

View File

@@ -0,0 +1,21 @@
+from exo.shared.types.api import FinishReason
+from exo.utils.pydantic_ext import TaggedModel
+
+
+class BaseRunnerResponse(TaggedModel):
+    pass
+
+
+class TokenizedResponse(BaseRunnerResponse):
+    prompt_tokens: int
+
+
+class GenerationResponse(BaseRunnerResponse):
+    text: str
+    token: int
+    # logprobs: list[float] | None = None # too big. we can change to be top-k
+    finish_reason: FinishReason | None = None
+
+
+class FinishedResponse(BaseRunnerResponse):
+    pass

View File

@@ -177,7 +177,7 @@ class MpReceiver[T]:
         try:
             item = self._state.buffer.get(block=False)
-            if item is MP_END_OF_STREAM:
+            if item == MP_END_OF_STREAM:
                 self.close()
                 raise EndOfStream
             assert not isinstance(item, _MpEndOfStream)
@@ -193,7 +193,7 @@ class MpReceiver[T]:
             return self.receive_nowait()
         except WouldBlock:
             item = self._state.buffer.get()
-            if item is MP_END_OF_STREAM:
+            if item == MP_END_OF_STREAM:
                 self.close()
                 raise EndOfStream from None
             assert not isinstance(item, _MpEndOfStream)
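One note on the `is` to `==` change above (the rationale is not stated in the diff, so treat this as an assumption): a sentinel pushed through a multiprocessing queue is pickled and a fresh object comes out on the receiving side, so identity checks can fail across processes, while equality keeps working if the sentinel type defines `__eq__`. A minimal sketch with a hypothetical sentinel class:

```python
import pickle


class _EndOfStream:
    # Hypothetical sentinel: all instances compare equal by type.
    def __eq__(self, other: object) -> bool:
        return isinstance(other, _EndOfStream)

    def __hash__(self) -> int:
        return hash(_EndOfStream)


END_OF_STREAM = _EndOfStream()

# Simulate the round trip a sentinel takes through a multiprocessing queue.
received = pickle.loads(pickle.dumps(END_OF_STREAM))

assert received is not END_OF_STREAM  # identity is lost across the process boundary
assert received == END_OF_STREAM      # equality still holds
```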

View File

@@ -1,2 +0,0 @@
-- Where should we check where the model is downloaded?
-- Error handling. How do we handle the scenario where an operation keeps failing to execute

View File

@@ -1,36 +0,0 @@
-import pytest
-
-from exo.shared.models.model_meta import get_model_meta
-from exo.shared.types.models import ModelMetadata
-from exo.shared.types.worker.shards import PipelineShardMetadata
-
-
-@pytest.fixture
-async def model_meta() -> ModelMetadata:
-    return await get_model_meta("mlx-community/Llama-3.2-1B-Instruct-4bit")
-
-
-@pytest.fixture
-def pipeline_shard_meta(model_meta: ModelMetadata):
-    def _pipeline_shard_meta(
-        num_nodes: int = 1, device_rank: int = 0
-    ) -> PipelineShardMetadata:
-        total_layers = 16
-        layers_per_node = total_layers // num_nodes
-        start_layer = device_rank * layers_per_node
-        end_layer = (
-            start_layer + layers_per_node
-            if device_rank < num_nodes - 1
-            else total_layers
-        )
-        return PipelineShardMetadata(
-            model_meta=model_meta,
-            device_rank=device_rank,
-            n_layers=total_layers,
-            start_layer=start_layer,
-            end_layer=end_layer,
-            world_size=num_nodes,
-        )
-
-    return _pipeline_shard_meta

View File

@@ -1,7 +1,7 @@
 import os
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Callable, Generator, Iterable, TypeVar
+from typing import Callable, Generator, Iterable
 
 import aiofiles
 import aiofiles.os as aios
@@ -9,10 +9,8 @@ from loguru import logger
 from exo.shared.types.worker.shards import ShardMetadata
 
-T = TypeVar("T")
 
-def filter_repo_objects(
+def filter_repo_objects[T](
     items: Iterable[T],
     *,
     allow_patterns: list[str] | str | None = None,
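For reference (general Python, not specific to this repo): the new signature uses PEP 695 type-parameter syntax, which declares the type variable inline and removes the module-level TypeVar. A small sketch of the same pattern under Python 3.12+:

```python
from typing import Iterable


def first_match[T](items: Iterable[T], *, prefix: str) -> T | None:
    # T is scoped to the function; no TypeVar("T") needed at module level.
    for item in items:
        if str(item).startswith(prefix):
            return item
    return None


print(first_match(["alpha", "beta"], prefix="be"))  # -> "beta"
```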

View File

@@ -1,9 +1,8 @@
 from typing import Any
 
-from mlx_lm.models.cache import KVCache
-
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.models.cache import KVCache
 
 # These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
 # For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function

View File

@@ -3,8 +3,14 @@ from functools import partial
 from inspect import signature
 from typing import TYPE_CHECKING, Callable, Protocol, cast, override
 
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.distributed import (
+    shard_inplace,
+    shard_linear,
+    sum_gradients,
+)
 from mlx_lm.models.cache import (
-    KVCache,
     _BaseCache, # pyright: ignore[reportPrivateUsage]
 )
 from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
@@ -13,16 +19,9 @@ from mlx_lm.models.llama import Model as LlamaModel
 from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
 from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
 
-import mlx.core as mx
-import mlx.nn as nn
-
 from exo.shared.types.worker.shards import (
     PipelineShardMetadata,
 )
-from mlx.nn.layers.distributed import (
-    shard_inplace,
-    shard_linear,
-    sum_gradients,
-)
 
 
 class _LayerCallable(Protocol):
@@ -94,7 +93,7 @@ class PipelineLastLayer(CustomMlxLayer):
             x, *args, **kwargs
         ).arguments.get("cache", None)
         assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
         output: mx.array = self.original_layer(x, *args, **kwargs)

View File

@@ -1,14 +1,16 @@
-# type: ignore
-# TODO: Fix this file, including types!
 from copy import deepcopy
 from typing import Callable
 
+import mlx.core as mx
 from mlx_lm import stream_generate
 from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 
-import mlx.core as mx
-from exo.engines.mlx import Model
-from exo.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
-from exo.engines.mlx.utils_mlx import make_kv_cache
+from exo.worker.engines.mlx import Model
+from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
+from exo.worker.engines.mlx.utils_mlx import make_kv_cache
 
 
 class KVPrefixCache:

View File

@@ -0,0 +1,18 @@
+# TODO: Do we want so many constants?
+# I think we want a lot of these as parameters?
+KV_GROUP_SIZE: int | None = 32
+KV_BITS: int | None = None
+ATTENTION_KV_BITS: int | None = 4
+MAX_TOKENS: int = 8192
+MAX_KV_SIZE: int | None = 3200
+KEEP_KV_SIZE: int | None = 1600
+QUANTIZE_MODEL_MODE: str | None = "affine"
+CACHE_GROUP_SIZE: int = 64
+KV_CACHE_BITS: int | None = 8
+TEMPERATURE: float = 1.0
+
+# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
+TRUST_REMOTE_CODE: bool = True
+
+# TODO: Do we really want this?
+HIDE_THINKING: bool = False
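A side note on the new annotations (an assumption about intent, not stated in the diff): marking a constant as `int | None` makes type checkers force call sites to handle the None case rather than passing it through silently. Sketch of a hypothetical consumer:

```python
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE


def describe_kv_quantization() -> str:
    # Pyright flags any use of KV_BITS as a plain int unless None is handled first.
    if KV_BITS is None:
        return "KV cache quantization disabled"
    return f"KV cache quantized to {KV_BITS} bits (group size {KV_GROUP_SIZE})"
```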

View File

@@ -1,6 +1,7 @@
 import os
 import resource
 import time
+from pathlib import Path
 from typing import Any, Callable, cast
 
 from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
@@ -8,29 +9,22 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
 from mlx_lm.sample_utils import make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 
-from exo.worker.runner.utils import get_weights_size
+from exo.worker.engines.mlx.constants import (
+    CACHE_GROUP_SIZE,
+    KV_CACHE_BITS,
+    TEMPERATURE,
+    TRUST_REMOTE_CODE,
+)
 
 try:
     from mlx_lm.tokenizer_utils import load_tokenizer
 except ImportError:
     from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
 
+import mlx.core as mx
+import mlx.nn as nn
 from mlx_lm.utils import load_model
 from pydantic import RootModel
 
-import mlx.core as mx
-import mlx.nn as nn
-from exo.engines.mlx import Model
-from exo.engines.mlx.auto_parallel import (
-    pipeline_auto_parallel,
-    tensor_auto_parallel,
-)
-from exo.engines.mlx.constants import (
-    CACHE_GROUP_SIZE,
-    KV_CACHE_BITS,
-    PATCH_SYSTEM_PROMPT,
-    TEMPERATURE,
-    TRUST_REMOTE_CODE,
-)
 from exo.shared.types.api import ChatCompletionMessageText
 from exo.shared.types.common import Host
 from exo.shared.types.memory import Memory
@@ -46,13 +40,31 @@ from exo.shared.types.worker.shards import (
     TensorShardMetadata,
 )
 from exo.worker.download.download_utils import build_model_path
+from exo.worker.engines.mlx import Model
+from exo.worker.engines.mlx.auto_parallel import (
+    pipeline_auto_parallel,
+    tensor_auto_parallel,
+)
 from exo.worker.runner.bootstrap import logger
 
 # Needed for 8 bit model
 resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
 
-mlx_rank: None | int = None
-mlx_world_size: None | int = None
+
+# TODO: Test this
+# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
+def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
+    return Memory.from_float_kb(
+        (model_shard_meta.end_layer - model_shard_meta.start_layer)
+        / model_shard_meta.n_layers
+        * model_shard_meta.model_meta.storage_size.in_kb
+        / (
+            1
+            if isinstance(model_shard_meta, PipelineShardMetadata)
+            else model_shard_meta.world_size
+        )
+    )
 
 
 def mx_barrier(group: mx.distributed.Group | None = None):
     mx.eval(
@@ -65,10 +77,10 @@ def mx_barrier(group: mx.distributed.Group | None = None):
 def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
-    if mlx_rank is None:
+    if group is None:
         return value
-    if mlx_rank == 0:
+    if group.rank() == 0:
         a = mx.array([value], dtype=mx.int32)
     else:
         a = mx.array([0], dtype=mx.int32)
@@ -154,10 +166,10 @@ def initialize_mlx(
logger.info(f"Single device used for {bound_instance.instance}") logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id) model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter() start_time = time.perf_counter()
model, config = load_model(model_path, strict=True) model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter() end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s") logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if isinstance(model.model, DeepseekV3Model): if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass pass
# model, config = quantize_model( # model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
@@ -189,9 +201,9 @@ def shard_and_load(
 ) -> tuple[nn.Module, TokenizerWrapper]:
     model_path = build_model_path(shard_metadata.model_meta.model_id)
-    model, config = load_model(model_path, lazy=True, strict=False)
+    model, _ = load_model(model_path, lazy=True, strict=False)
     logger.debug(model)
-    if isinstance(model.model, DeepseekV3Model):
+    if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
         pass
 
     # TODO: See if we should quantize the model.
     # def is_attention_layer(path: str) -> bool:
@@ -199,7 +211,6 @@ def shard_and_load(
# return "self_attn" in path and "layernorm" not in path # return "self_attn" in path and "layernorm" not in path
# def quant_predicate(path: str, module: nn.Module): # def quant_predicate(path: str, module: nn.Module):
# if not isinstance(module, nn.Linear): # if not isinstance(module, nn.Linear):
# return False # return False
@@ -237,7 +248,7 @@ def shard_and_load(
     return model, tokenizer
 
 
-def get_tokenizer(model_path: str, shard_metadata: ShardMetadata):
+def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
     tokenizer = cast(
         TokenizerWrapper,
         load_tokenizer(
@@ -262,7 +273,7 @@ def apply_chat_template(
     messages = chat_task_data.messages
     formatted_messages: list[dict[str, Any]] = []
-    for i, message in enumerate(messages):
+    for _, message in enumerate(messages):
         if isinstance(message.content, ChatCompletionMessageText):
             message.content = message.content.text
         if isinstance(message.content, list):
@@ -276,7 +287,7 @@ def apply_chat_template(
         # Null values are not valid when applying templates in tokenizer
         formatted_messages.append(
-            {k: v for k, v in message.model_dump().items() if v is not None}
+            {k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
         )
 
     prompt: str = tokenizer.apply_chat_template( # type: ignore
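A worked example of the `get_weights_size` estimate moved into this file (the numbers below are made up purely for illustration): the shard's weight footprint is the on-disk model size scaled by the fraction of layers the shard holds, divided again by the group size for tensor shards.

```python
# Hypothetical model: ~8 GB on disk, 16 layers.
storage_kb = 8_000_000
n_layers = 16
start_layer, end_layer = 0, 8  # this shard holds half the layers

# Pipeline shard: proportional to its layer span.
pipeline_kb = (end_layer - start_layer) / n_layers * storage_kb
assert pipeline_kb == 4_000_000

# Tensor shard: the same span is split again across the whole group.
world_size = 2
tensor_kb = (end_layer - start_layer) / n_layers * storage_kb / world_size
assert tensor_kb == 2_000_000
```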

View File

@@ -226,9 +226,7 @@ class Worker:
                         task_id=task.task_id, task_status=TaskStatus.Running
                     )
                 )
-                await self._handle_shard_download_process(
-                    task, initial_progress
-                )
+                self._handle_shard_download_process(task, initial_progress)
             case Shutdown(runner_id=runner_id):
                 await self.runners.pop(runner_id).start_task(task)
             case task:
@@ -313,7 +311,7 @@ class Worker:
         self._tg.start_soon(runner.run)
         return runner
 
-    async def _handle_shard_download_process(
+    def _handle_shard_download_process(
         self,
         task: DownloadModel,
         initial_progress: RepoDownloadProgress,

View File

@@ -17,6 +17,7 @@ from exo.shared.types.tasks import (
 from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
 from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
 from exo.shared.types.worker.runners import (
+    RunnerFailed,
     RunnerId,
     RunnerLoaded,
     RunnerLoading,
@@ -59,16 +60,21 @@ def _kill_runner(
     instances: Mapping[InstanceId, Instance],
 ) -> Shutdown | None:
     for runner in runners.values():
+        runner_id = runner.bound_instance.bound_runner_id
         if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
-            return Shutdown(
-                instance_id=instance_id, runner_id=runner.bound_instance.bound_runner_id
-            )
-    """ --- Potential code to kill a runner if any runners in its instance have failed ---
-    global_runners_in_instance = runner.bound_instance.instance.shard_assignments.node_to_runner.values()
-    if any(isinstance(all_runners[runner_id], RunnerFailed) for runner_id in global_runners_in_instance if runner_id != runner.bound_instance.bound_runner_id):
-        Shutdown(instance_id=runner.bound_instance.instance.instance_id, runner_id=runner.bound_instance.bound_runner_id)
-    """
+            return Shutdown(instance_id=instance_id, runner_id=runner_id)
+        for (
+            global_runner_id
+        ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
+            if runner_id == global_runner_id:
+                continue
+            if isinstance(all_runners.get(global_runner_id, None), RunnerFailed):
+                return Shutdown(
+                    instance_id=instance_id,
+                    runner_id=runner_id,
+                )
 
 
 def _create_runner(
@@ -125,25 +131,36 @@ def _load_model(
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]], global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> LoadModel | None: ) -> LoadModel | None:
for runner in runners.values(): for runner in runners.values():
if ( instance = runner.bound_instance.instance
all( shard_assignments = instance.shard_assignments
all_downloads_complete_local = all(
any(
isinstance(dp, DownloadCompleted) isinstance(dp, DownloadCompleted)
if dp.shard_metadata and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
== runner.bound_instance.instance.shard_assignments.runner_to_shard[rid]
else True
for nid, rid in runner.bound_instance.instance.shard_assignments.node_to_runner.items()
for dp in global_download_status[nid] for dp in global_download_status[nid]
) )
and isinstance(runner.status, RunnerWaitingForModel) for nid, rid in shard_assignments.node_to_runner.items()
and all( )
isinstance(
all_runners.get(global_runner_id, None), runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
) all_runners_expecting_model = all(
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard isinstance(
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
) )
for global_runner_id in shard_assignments.runner_to_shard
)
if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
): ):
return LoadModel(instance_id=runner.bound_instance.instance.instance_id) return LoadModel(instance_id=instance.instance_id)
return None
def _ready_to_warmup( def _ready_to_warmup(
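Condensed, the new _load_model gate checks three conditions per local runner: all shard downloads for the instance are complete, this runner is waiting for the model, and every peer runner is at least expecting it. The predicate below is a sketch; the status and download classes are opaque stand-ins named after the diff, and the flat signature is an assumption.

from typing import Any, Mapping, Sequence

# Opaque stand-ins for the statuses referenced in the diff (illustrative only).
class RunnerWaitingForModel: ...
class RunnerLoading: ...
class RunnerLoaded: ...
class DownloadCompleted:
    shard_metadata: Any = None

def should_load_model(
    runner: Any,
    all_runners: Mapping[str, Any],
    global_download_status: Mapping[str, Sequence[Any]],
) -> bool:
    assignments = runner.bound_instance.instance.shard_assignments

    # 1. Every node in the instance has a completed download matching its runner's shard.
    downloads_done = all(
        any(
            isinstance(dp, DownloadCompleted)
            and dp.shard_metadata == assignments.runner_to_shard[rid]
            for dp in global_download_status[nid]
        )
        for nid, rid in assignments.node_to_runner.items()
    )
    # 2. This runner is waiting for the model.
    runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
    # 3. Every runner in the instance is waiting, loading, or already loaded.
    peers_expecting = all(
        isinstance(
            all_runners.get(global_runner_id),
            (RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
        )
        for global_runner_id in assignments.runner_to_shard
    )
    return downloads_done and runner_is_waiting and peers_expecting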
@@ -151,29 +168,37 @@ def _ready_to_warmup(
all_runners: Mapping[RunnerId, RunnerStatus], all_runners: Mapping[RunnerId, RunnerStatus],
) -> StartWarmup | None: ) -> StartWarmup | None:
for runner in runners.values(): for runner in runners.values():
if isinstance(runner.status, RunnerLoaded) and ( instance = runner.bound_instance.instance
( shard_assignments = instance.shard_assignments
all( shard = runner.bound_instance.bound_shard
isinstance( device_rank = shard.device_rank
all_runners.get(global_runner_id, None), runner_id = runner.bound_instance.bound_runner_id
(RunnerLoaded, RunnerWarmingUp),
) is_runner_loaded = isinstance(runner.status, RunnerLoaded)
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
) # Rank != 0
and runner.bound_instance.bound_shard.device_rank != 0 all_runners_loaded_or_warming_up = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
) )
or ( for global_runner_id in shard_assignments.runner_to_shard
all( )
isinstance(
all_runners.get(global_runner_id, None), (RunnerWarmingUp) # Rank= 0
) all_other_runners_warming_up = all(
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
if global_runner_id != runner.bound_instance.bound_runner_id for global_runner_id in shard_assignments.runner_to_shard
) if global_runner_id != runner_id
and runner.bound_instance.bound_shard.device_rank == 0 )
)
): nonzero_rank_ready = device_rank != 0 and all_runners_loaded_or_warming_up
return StartWarmup(instance_id=runner.bound_instance.instance.instance_id) zero_rank_ready = device_rank == 0 and all_other_runners_warming_up
if is_runner_loaded and (nonzero_rank_ready or zero_rank_ready):
return StartWarmup(instance_id=instance.instance_id)
return None
def _pending_tasks( def _pending_tasks(
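The warmup gate introduced above can likewise be read as a small predicate: a loaded non-zero rank starts once every runner in the instance is loaded or warming up, while rank 0 waits until every other runner is already warming up. Stand-in status classes and a flattened signature are used below for illustration.

from typing import Any, Mapping

class RunnerLoaded: ...
class RunnerWarmingUp: ...

def ready_to_warmup(
    device_rank: int,
    runner_id: str,
    runner_status: Any,
    instance_runner_statuses: Mapping[str, Any],
) -> bool:
    if not isinstance(runner_status, RunnerLoaded):
        return False
    if device_rank != 0:
        # Non-zero ranks start once every runner in the instance is loaded or warming up.
        return all(
            isinstance(status, (RunnerLoaded, RunnerWarmingUp))
            for status in instance_runner_statuses.values()
        )
    # Rank 0 waits until every *other* runner is already warming up.
    return all(
        isinstance(status, RunnerWarmingUp)
        for rid, status in instance_runner_statuses.items()
        if rid != runner_id
    )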

View File

@@ -1,8 +1,4 @@
"""--- not doing this anymore
import faulthandler
import os import os
import sys
"""
import loguru import loguru
@@ -11,45 +7,25 @@ from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.instances import BoundInstance
from exo.utils.channels import MpReceiver, MpSender from exo.utils.channels import MpReceiver, MpSender
""" -- not doing this anymore logger: "loguru.Logger"
def _redirect_stderr_to_file(path: str) -> None:
# Replace fd 2 (stderr) with a file descriptor pointing to `path`
fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644) if os.getenv("EXO_TESTS") == "1":
os.dup2(fd, 2) logger = loguru.logger
os.close(fd)
# Rebind sys.stderr so Python's own writes go to the new fd as well (line-buffered)
sys.stderr = os.fdopen(2, "w", buffering=1, closefd=False)
"""
def entrypoint( def entrypoint(
bound_instance: BoundInstance, bound_instance: BoundInstance,
event_sender: MpSender[Event], event_sender: MpSender[Event],
task_receiver: MpReceiver[Task], task_receiver: MpReceiver[Task],
# err_path: str,
_logger: "loguru.Logger", _logger: "loguru.Logger",
) -> None: ) -> None:
"""
Minimal entrypoint for the spawned child process.
It redirects fd=2 (stderr) to a pipe provided by the parent, *then* imports
the heavy runner module so that any C/C++ or MLX logs/crashes land in that pipe.
"""
""" --- not doing this anymore
_redirect_stderr_to_file(err_path)
faulthandler.enable(file=sys.stderr, all_threads=True)
"""
import os
os.environ["MLX_METAL_FAST_SYNCH"] = "1" os.environ["MLX_METAL_FAST_SYNCH"] = "1"
global logger global logger
logger = _logger logger = _logger
# Import the heavy runner only after stderr is redirected # Import main after setting global logger - this lets us just import logger from this module
from exo.worker.runner.runner import main from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver) main(bound_instance, event_sender, task_receiver)
logger: "loguru.Logger"

View File

@@ -5,21 +5,19 @@ from mlx_lm import stream_generate
from mlx_lm.models.cache import KVCache from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.engines.mlx import Model
# from exo.engines.mlx.cache import KVPrefixCache # from exo.engines.mlx.cache import KVPrefixCache
from exo.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS from exo.shared.types.api import ChatCompletionMessage, FinishReason
from exo.engines.mlx.utils_mlx import ( from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template, apply_chat_template,
make_kv_cache, make_kv_cache,
mx_barrier, mx_barrier,
) )
from exo.shared.openai_compat import FinishReason
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.commands_runner import (
GenerationResponse,
)
from exo.worker.runner.bootstrap import logger from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device()) generation_stream = mx.new_stream(mx.default_device())

View File

@@ -1,9 +1,6 @@
import time import time
from exo.engines.mlx.utils_mlx import ( from exo.shared.types.api import ChatCompletionMessageText
initialize_mlx,
mlx_force_oom,
)
from exo.shared.types.chunks import TokenChunk from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import ( from exo.shared.types.events import (
ChunkGenerated, ChunkGenerated,
@@ -20,11 +17,10 @@ from exo.shared.types.tasks import (
Task, Task,
TaskStatus, TaskStatus,
) )
from exo.shared.types.worker.commands_runner import (
GenerationResponse,
# TokenizedResponse,
)
from exo.shared.types.worker.instances import BoundInstance from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import ( from exo.shared.types.worker.runners import (
RunnerFailed, RunnerFailed,
RunnerLoaded, RunnerLoaded,
@@ -37,6 +33,10 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp, RunnerWarmingUp,
) )
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger from exo.worker.runner.bootstrap import logger
from exo.worker.runner.generate import mlx_generate, warmup_inference from exo.worker.runner.generate import mlx_generate, warmup_inference
@@ -142,27 +142,8 @@ def main(
runner_id=runner_id, runner_status=current_status runner_id=runner_id, runner_status=current_status
) )
) )
# Ensure we have a chat-completion task subtype assert task_params.messages[0].content is not None
# TODO: this is a hack, why are we only looking at the first message? should have a tokenizer _check_for_debug_prompts(task_params.messages[0].content)
prompt = task_params.messages[0]
if (
prompt.content is not None
and "EXO RUNNER MUST FAIL" in prompt.content
):
logger.info("raising exception")
raise Exception(
"Artificial runner exception - for testing purposes only."
)
if (
prompt.content is not None
and "EXO RUNNER MUST OOM" in prompt.content
):
mlx_force_oom()
if (
prompt.content is not None
and "EXO RUNNER MUST TIMEOUT" in prompt.content
):
time.sleep(100)
# Generate responses using the actual MLX generation # Generate responses using the actual MLX generation
for response in mlx_generate( for response in mlx_generate(
@@ -186,9 +167,9 @@ def main(
), ),
) )
) )
# case TokenizedResponse(): # case TokenizedResponse():
# TODO: something here ig # TODO: something here ig
# logger.info("Finished tokenizing?") logger.info("Finished tokenizing?")
current_status = RunnerReady() current_status = RunnerReady()
logger.info("runner ready") logger.info("runner ready")
@@ -233,3 +214,29 @@ def main(
event_sender.join() event_sender.join()
task_receiver.join() task_receiver.join()
logger.info("bye from the runner") logger.info("bye from the runner")
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(
prompt: str | ChatCompletionMessageText | list[ChatCompletionMessageText],
):
if isinstance(prompt, list):
if len(prompt) == 0:
logger.debug("Empty message prompt received in debug prompt")
return
prompt = prompt[0]
if isinstance(prompt, ChatCompletionMessageText):
prompt = prompt.text
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")
raise Exception("Artificial runner exception - for testing purposes only.")
if EXO_RUNNER_MUST_OOM in prompt:
mlx_force_oom()
if EXO_RUNNER_MUST_TIMEOUT in prompt:
time.sleep(100)
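As a usage note, any test can exercise these debug hooks by embedding one of the sentinel strings in the first user message. The sketch below uses the import paths that appear elsewhere in this diff; the exact fixture wiring around it is an assumption.

from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.tasks import ChatCompletionTaskParams

# A prompt carrying the EXO_RUNNER_MUST_FAIL sentinel: when the runner dequeues a task
# built from these params, _check_for_debug_prompts raises and the supervisor's
# failure-handling path is exercised end to end.
params = ChatCompletionTaskParams(
    model="gpt-4",
    messages=[
        ChatCompletionMessage(
            role="user", content="Artificial prompt: EXO RUNNER MUST FAIL"
        )
    ],
    stream=True,
)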

View File

@@ -1,64 +0,0 @@
import asyncio
import contextlib
import sys
import psutil
from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
async def kill_process_tree(runner_process: asyncio.subprocess.Process) -> None:
"""Kill the process and all its children forcefully."""
if runner_process.returncode is not None:
return # Process already dead
try:
# Get the main process
pid = runner_process.pid
# Find all child processes
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
# Kill all children first (bottom-up)
for child in reversed(children):
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
child.kill() # SIGKILL
# Kill the parent
with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
parent.kill() # SIGKILL
except psutil.NoSuchProcess:
# Process already gone, try subprocess kill anyway
runner_process.kill()
# Wait for the subprocess to exit
try:
await asyncio.wait_for(runner_process.wait(), timeout=2.0)
except asyncio.TimeoutError:
logger.error(f"Process {pid} did not exit after kill signal")
except Exception as e:
logger.error(f"Error killing process tree: {e}")
def get_runner_command() -> list[str]:
python = sys.executable
return [python, "-m", "exo.worker.runner.runner"]
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_meta.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
else model_shard_meta.world_size
)
)
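A worked example of the deleted get_weights_size formula, with assumed numbers: a shard holding half the layers of a model whose weights occupy 4 GiB is charged 2 GiB under pipeline sharding, and 1 GiB when further split across a tensor-parallel world of size 2.

# Assumed numbers, purely to evaluate the formula above.
storage_kb = 4 * 1024 * 1024            # 4 GiB of weights, expressed in KiB
layer_fraction = (16 - 0) / 32          # shard covers layers 0-15 of a 32-layer model

pipeline_kb = layer_fraction * storage_kb / 1   # PipelineShardMetadata branch
tensor_kb = layer_fraction * storage_kb / 2     # non-pipeline branch with world_size == 2

print(pipeline_kb)  # 2097152.0 KiB == 2 GiB
print(tensor_kb)    # 1048576.0 KiB == 1 GiB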

View File

@@ -0,0 +1,57 @@
Unit tests:
1. Test worker plans as expected (a concrete sketch follows this plan)
- State transitions are correct
- Unexpected states throw
2. Test runner
- Stays loaded
- Unloads under end condition
- Accepts tasks
- Returns ChunkGenerated events
3. Test mlx engine
 - Autoparallel across n identical nodes returns tensors 1/n of the size
 - mx_barrier forces computation
 - Distributed init returns the expected configuration
 - initialize_mlx sets the wired memory limit
 - shard_and_load returns the expected model
- Quantization returns quantized layers
4. Download
- hits the correct endpoint
- normalizes tags correctly
- updates download progress
5. Serialization/Deserialization of tagged models
Integration tests:
1. Test model inference is "sensible" (per-configuration)
- Non-empty response
- Sensible inference speed
 - Answers are non-gibberish across many seeds (e.g. "What is the capital of France?" -> "Paris" appears in the answer)
 - Answer is identical for a given seed
2. Test that node count does not affect inference result (per-configuration)
 - Llama on 1 node and on 2 nodes returns the same result, given temperature 0 and a fixed seed
 - Repeat for all configurations (Ring/Ibv, Pipeline/Tensor)
3. Test supervisor catches exceptions gracefully
- Timeouts
- OOM
- MLX error
4. Distributed init memory requirements are as expected
5. MLX
 - KVCache length matches the number of prompt tokens
 - Prefix cache (once implemented)
6. Spin-up either creates a runner or transitions to failed status
Regression tests:
1. Per-configuration baseline performance: no more than a 20% drop in performance (device, node count, model, strategy, backend)
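To make the first unit-test bullet concrete, a state-transition check could take roughly the shape below. The tiny planner and status classes are hypothetical stand-ins for the worker's real planning functions, not its API.

import pytest

class RunnerLoaded: ...
class RunnerWarmingUp: ...

def plan_next(status: object, peer_statuses: list[object]) -> str:
    # Hypothetical planner: a loaded runner whose peers are all warming up
    # should decide to start warmup; anything unrecognised is an error.
    if isinstance(status, RunnerLoaded) and all(
        isinstance(s, RunnerWarmingUp) for s in peer_statuses
    ):
        return "start_warmup"
    raise ValueError(f"unexpected state: {status!r}")

def test_loaded_runner_with_warming_peers_starts_warmup():
    assert plan_next(RunnerLoaded(), [RunnerWarmingUp()]) == "start_warmup"

def test_unexpected_state_throws():
    with pytest.raises(ValueError):
        plan_next(object(), [])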

View File

@@ -1,165 +0,0 @@
from typing import Callable
import pytest
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.common import Host, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import (
ChatCompletionTask,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance, InstanceStatus
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.main import Worker
from exo.worker.tests.constants import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
TASK_1_ID,
)
from .worker_management import (
WorkerMailbox,
create_worker_and_mailbox,
create_worker_void_mailbox,
create_worker_with_old_mailbox,
)
@pytest.fixture
def worker_void_mailbox() -> Worker:
return create_worker_void_mailbox(NODE_A)
@pytest.fixture
def worker_and_mailbox() -> tuple[Worker, WorkerMailbox]:
return create_worker_and_mailbox(NODE_A)
@pytest.fixture
def two_workers_with_shared_mailbox() -> tuple[Worker, Worker, WorkerMailbox]:
worker1, mailbox = create_worker_and_mailbox(NODE_A)
worker2 = create_worker_with_old_mailbox(NODE_B, mailbox)
return worker1, worker2, mailbox
@pytest.fixture
def user_message() -> str:
"""Override this fixture in tests to customize the message"""
return "Hello, how are you?"
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta("mlx-community/Llama-3.2-1B-Instruct-4bit")
@pytest.fixture
def hosts():
def _hosts(count: int, offset: int = 0) -> list[Host]:
return [
Host(
ip="127.0.0.1",
port=5000 + offset + i,
)
for i in range(count)
]
return _hosts
@pytest.fixture
def pipeline_shard_meta(
model_meta: ModelMetadata,
) -> Callable[[int, int], PipelineShardMetadata]:
def _pipeline_shard_meta(
num_nodes: int = 1, device_rank: int = 0
) -> PipelineShardMetadata:
total_layers = model_meta.n_layers
layers_per_node = total_layers // num_nodes
start_layer = device_rank * layers_per_node
end_layer = (
start_layer + layers_per_node
if device_rank < num_nodes - 1
else total_layers
)
return PipelineShardMetadata(
model_meta=model_meta,
device_rank=device_rank,
n_layers=total_layers,
start_layer=start_layer,
end_layer=end_layer,
world_size=num_nodes,
)
return _pipeline_shard_meta
@pytest.fixture
def instance(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
):
def _instance(
instance_id: InstanceId | None = None,
node_id: NodeId | None = None,
runner_id: RunnerId | None = None,
model_id: ModelId | None = None,
) -> Instance:
resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
resolved_node_id = node_id if node_id is not None else NODE_A
resolved_runner_id = runner_id if runner_id is not None else RUNNER_1_ID
resolved_model_id = model_id if model_id is not None else MODEL_A_ID
shard_assignments = ShardAssignments(
model_id=resolved_model_id,
runner_to_shard={resolved_runner_id: pipeline_shard_meta(1, 0)},
node_to_runner={resolved_node_id: resolved_runner_id},
)
return Instance(
instance_id=resolved_instance_id,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(1),
)
return _instance
@pytest.fixture
def completion_create_params(user_message: str) -> ChatCompletionTaskParams:
return ChatCompletionTaskParams(
model="gpt-4",
messages=[ChatCompletionMessage(role="user", content=user_message)],
stream=True,
)
@pytest.fixture
def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
def _chat_completion_task(
instance_id: InstanceId | None = None,
task_id: TaskId | None = None,
user_message: str = "Hello",
) -> ChatCompletionTask:
resolved_instance_id = instance_id if instance_id is not None else INSTANCE_1_ID
resolved_task_id = task_id if task_id is not None else TASK_1_ID
return ChatCompletionTask(
task_id=resolved_task_id,
command_id=COMMAND_1_ID,
instance_id=resolved_instance_id,
task_status=TaskStatus.Pending,
task_params=completion_create_params,
)
return _chat_completion_task

View File

@@ -3,7 +3,7 @@ from typing import Final
from exo.shared.types.common import CommandId, NodeId from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.common import InstanceId, RunnerId from exo.shared.types.worker.instances import InstanceId, RunnerId
MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa") MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")

View File

@@ -1,49 +0,0 @@
import time
from typing import Callable
import pytest
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.download.shard_downloader import ShardDownloader
@pytest.mark.slow
@pytest.mark.asyncio
async def test_shard_downloader(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
):
shard_downloader: ShardDownloader = exo_shard_downloader()
shard_downloader.on_progress(
lambda shard, progress: print(f"Download progress: {progress}")
)
shard_metadata = pipeline_shard_meta(1, 0)
path = await shard_downloader.ensure_shard(shard_metadata)
assert path.exists()
downloaded_model_path = path.parent / "mlx-community--Llama-3.2-1B-Instruct-4bit"
assert (downloaded_model_path / "config.json").exists()
assert (downloaded_model_path / "model.safetensors").exists()
assert (downloaded_model_path / "model.safetensors.index.json").exists()
assert (downloaded_model_path / "special_tokens_map.json").exists()
assert (downloaded_model_path / "tokenizer.json").exists()
assert (downloaded_model_path / "tokenizer_config.json").exists()
expected_files_and_sizes = [
("config.json", 1121),
("model.safetensors", 695283921),
("model.safetensors.index.json", 26159),
("special_tokens_map.json", 296),
("tokenizer.json", 17209920),
("tokenizer_config.json", 54558),
]
for filename, expected_size in expected_files_and_sizes:
file_path = downloaded_model_path / filename
assert file_path.stat().st_size == expected_size, f"{filename} size mismatch"
start_time = time.monotonic()
path_again = await shard_downloader.ensure_shard(shard_metadata)
duration = time.monotonic() - start_time
assert path_again == path
assert duration < 5, f"Second call to ensure_shard took too long: {duration:.2f}s"

View File

@@ -1,65 +0,0 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance
from exo.shared.types.worker.ops import (
AssignRunnerOp,
RunnerUpOp,
)
from exo.shared.types.worker.runners import RunnerId
from exo.worker.main import Worker
from exo.worker.tests.constants import INSTANCE_1_ID, RUNNER_1_ID
@pytest.fixture
def user_message():
return "What, according to Douglas Adams, is the meaning of life, the universe and everything?"
# TODO: instance_id and runner_id are selectable.
@pytest.fixture
async def worker_with_assigned_runner(
worker_void_mailbox: Worker,
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
"""Fixture that provides a worker with an already assigned runner."""
worker = worker_void_mailbox
instance_id = INSTANCE_1_ID
runner_id = RUNNER_1_ID
instance_obj: Instance = instance(instance_id, worker.node_id, runner_id)
# Assign the runner
assign_op = AssignRunnerOp(
runner_id=runner_id,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[runner_id],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
async for _ in worker.execute_op(assign_op):
pass
return worker, instance_obj
@pytest.fixture
async def worker_with_running_runner(
worker_with_assigned_runner: tuple[Worker, Instance],
):
"""Fixture that provides a worker with an already assigned runner."""
worker, instance_obj = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
async for _ in worker.execute_op(runner_up_op):
pass
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.runner_process.is_alive()
return worker, instance_obj

View File

@@ -1,171 +0,0 @@
from typing import Callable
import pytest
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
RunnerDeleted,
RunnerStatusUpdated,
TaskStateUpdated,
)
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus
from exo.shared.types.worker.common import RunnerId
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from exo.shared.types.worker.runners import (
DownloadingRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunningRunnerStatus,
StartingRunnerStatus,
)
from exo.worker.main import Worker
from exo.worker.tests.constants import (
RUNNER_1_ID,
)
from exo.worker.tests.test_handlers.utils import read_events_op
@pytest.mark.asyncio
async def test_assign_op(
worker_void_mailbox: Worker,
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
):
worker = worker_void_mailbox
instance_obj: Instance = instance(InstanceId(), worker.node_id, RUNNER_1_ID)
assign_op = AssignRunnerOp(
runner_id=RUNNER_1_ID,
shard_metadata=instance_obj.shard_assignments.runner_to_shard[RUNNER_1_ID],
hosts=instance_obj.hosts,
instance_id=instance_obj.instance_id,
)
events = await read_events_op(worker, assign_op)
# We should have a status update saying 'starting'.
assert len(events) == 2
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, DownloadingRunnerStatus)
assert isinstance(events[1], RunnerStatusUpdated)
assert isinstance(events[1].runner_status, InactiveRunnerStatus)
# And the runner should be assigned
assert RUNNER_1_ID in worker.assigned_runners
assert isinstance(worker.assigned_runners[RUNNER_1_ID].status, InactiveRunnerStatus)
@pytest.mark.asyncio
async def test_unassign_op(worker_with_assigned_runner: tuple[Worker, Instance]):
worker, _ = worker_with_assigned_runner
unassign_op = UnassignRunnerOp(runner_id=RUNNER_1_ID)
events = await read_events_op(worker, unassign_op)
# We should have no assigned runners and no events were emitted
assert len(worker.assigned_runners) == 0
assert len(events) == 1
assert isinstance(events[0], RunnerDeleted)
@pytest.mark.asyncio
async def test_runner_up_op(
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
events = await read_events_op(worker, runner_up_op)
assert len(events) == 2
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, StartingRunnerStatus)
assert isinstance(events[1], RunnerStatusUpdated)
assert isinstance(events[1].runner_status, LoadedRunnerStatus)
# Is the runner actually running?
supervisor = next(iter(worker.assigned_runners.values())).runner
assert supervisor is not None
assert supervisor.runner_process.is_alive()
full_response = ""
async for chunk in supervisor.stream_response(task=chat_completion_task()):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
assert "42" in full_response.lower(), (
f"Expected '42' in response, but got: {full_response}"
)
runner = worker.assigned_runners[RUNNER_1_ID].runner
assert runner is not None
await runner.astop() # Neat cleanup.
@pytest.mark.asyncio
async def test_runner_down_op(worker_with_running_runner: tuple[Worker, Instance]):
worker, _ = worker_with_running_runner
runner_down_op = RunnerDownOp(runner_id=RUNNER_1_ID)
events = await read_events_op(worker, runner_down_op)
assert len(events) == 1
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, InactiveRunnerStatus)
@pytest.mark.asyncio
async def test_execute_task_op(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_running_runner
execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=chat_completion_task())
events = await read_events_op(worker, execute_task_op)
assert len(events) > 20
print(f"{events=}")
assert isinstance(events[0], RunnerStatusUpdated)
assert isinstance(events[0].runner_status, RunningRunnerStatus)
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.Running # It tried to start.
assert isinstance(events[-2], TaskStateUpdated)
assert events[-2].task_status == TaskStatus.Complete # It tried to start.
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(
events[-1].runner_status, LoadedRunnerStatus
) # It should not have failed.
gen_events: list[ChunkGenerated] = [
x for x in events if isinstance(x, ChunkGenerated)
]
text_chunks: list[TokenChunk] = [
x.chunk for x in gen_events if isinstance(x.chunk, TokenChunk)
]
assert len(text_chunks) == len(events) - 4
output_text = "".join([x.text for x in text_chunks])
assert "42" in output_text
runner = worker.assigned_runners[RUNNER_1_ID].runner
assert runner is not None
await runner.astop() # Neat cleanup.

View File

@@ -1,83 +0,0 @@
## Tests for worker state handlers
import asyncio
from typing import Callable
import pytest
from exo.shared.types.tasks import ChatCompletionTask
from exo.shared.types.worker.common import RunnerError
from exo.shared.types.worker.instances import Instance
from exo.shared.types.worker.ops import (
ExecuteTaskOp,
RunnerUpOp,
)
from exo.worker.main import Worker
from exo.worker.tests.constants import RUNNER_1_ID
from exo.worker.tests.test_handlers.utils import read_events_op
@pytest.mark.asyncio
async def test_runner_up_fails(
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_assigned_runner
worker.assigned_runners[RUNNER_1_ID].shard_metadata.immediate_exception = True
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
with pytest.raises(RunnerError):
await read_events_op(worker, runner_up_op)
@pytest.mark.asyncio
async def test_runner_up_timeouts(
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_assigned_runner
worker.assigned_runners[RUNNER_1_ID].shard_metadata.should_timeout = 10
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
with pytest.raises(asyncio.TimeoutError):
await read_events_op(worker, runner_up_op)
@pytest.mark.asyncio
async def test_execute_task_fails(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_running_runner
task = chat_completion_task()
messages = task.task_params.messages
messages[0].content = "Artificial prompt: EXO RUNNER MUST FAIL"
execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=task)
with pytest.raises(RunnerError):
await read_events_op(worker, execute_task_op)
@pytest.mark.asyncio
async def test_execute_task_timeouts(
worker_with_running_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[], ChatCompletionTask],
):
worker, _ = worker_with_running_runner
task = chat_completion_task()
messages = task.task_params.messages
messages[0].content = "Artificial prompt: EXO RUNNER MUST TIMEOUT"
execute_task_op = ExecuteTaskOp(runner_id=RUNNER_1_ID, task=task)
with pytest.raises(asyncio.TimeoutError):
await read_events_op(worker, execute_task_op)
# TODO: Much more to do here!
# runner assigned download stuff

View File

@@ -1,17 +0,0 @@
## Tests for worker state handlers
from exo.shared.types.events import (
Event,
)
from exo.shared.types.worker.ops import (
RunnerOp,
)
from exo.worker.main import Worker
async def read_events_op(worker: Worker, op: RunnerOp) -> list[Event]:
events: list[Event] = []
async for event in worker.execute_op(op):
events.append(event)
return events

View File

@@ -1,262 +0,0 @@
import asyncio
from typing import Callable
import pytest
from anyio import create_task_group
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.common import CommandId, Host, NodeId
from exo.shared.types.events import (
InstanceCreated,
InstanceDeleted,
TaskCreated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletionTask,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.common import InstanceId, RunnerId
from exo.shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.main import Worker
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
)
from exo.worker.tests.worker_management import (
WorkerMailbox,
read_streaming_response,
)
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "What's the capital of Japan?"
async def test_runner_inference(
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(
instance=instance_value,
),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
# TODO: This needs to get fixed - sometimes it misses the 'starting' event.
(
seen_task_started,
seen_task_finished,
response_string,
_,
) = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert "tokyo" in response_string.lower()
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(0.3)
worker.shutdown()
# TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly.
async def test_2_runner_inference(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
):
worker1, worker2, global_events = two_workers_with_shared_mailbox
async with create_task_group() as tg:
tg.start_soon(worker1.run)
tg.start_soon(worker2.run)
## Instance
model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit")
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1),
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
(
seen_task_started,
seen_task_finished,
response_string,
_,
) = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert "tokyo" in response_string.lower()
_ = global_events.collect()
await asyncio.sleep(1.0)
events = global_events.collect()
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(2.0)
worker1.shutdown()
worker2.shutdown()
# TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly.
# TODO: Multi message parallel
async def test_2_runner_multi_message(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
):
worker1, worker2, global_events = two_workers_with_shared_mailbox
async with create_task_group() as tg:
tg.start_soon(worker1.run)
tg.start_soon(worker2.run)
## Instance
model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit")
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1),
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
# Task - we have three messages here, which is what the task is about
completion_create_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
ChatCompletionMessage(
role="assistant", content="The capital of France is Paris."
),
ChatCompletionMessage(
role="user",
content="Ok great. Now write me a haiku about what you can do there.",
),
],
stream=True,
)
task = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=CommandId(),
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=completion_create_params,
)
await global_events.append_events(
[
InstanceCreated(instance=instance),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
(
seen_task_started,
seen_task_finished,
response_string,
_,
) = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert any(
keyword in response_string.lower()
for keyword in ("kiss", "paris", "art", "love")
)
_ = global_events.collect()
await asyncio.sleep(1.0)
events = global_events.collect()
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
worker1.shutdown()
worker2.shutdown()
# TODO: Ensure this is sufficient, or add mechanism to fail the test gracefully if workers do not shutdown properly.

View File

@@ -1,311 +0,0 @@
import asyncio
from collections.abc import AsyncGenerator
from types import CoroutineType
from typing import Any, Callable
import pytest
from _pytest.monkeypatch import MonkeyPatch
from anyio import create_task_group
from exo.shared.types.chunks import GenerationChunk, TokenChunk
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
TaskCreated,
TaskFailed,
TaskStateUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId, RunnerId
from exo.shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from exo.shared.types.worker.runners import FailedRunnerStatus
from exo.worker.main import Worker
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
TASK_1_ID,
)
from exo.worker.tests.worker_management import (
WorkerMailbox,
until_event_with_timeout,
)
@pytest.fixture
def user_message():
"""Override this fixture in tests to customize the message"""
return "Who is the longest ruling monarch of England?"
async def test_stream_response_failed_always(
monkeypatch: MonkeyPatch,
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
) -> None:
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]]
| None = None,
) -> AsyncGenerator[GenerationChunk, None]:
raise RuntimeError("Simulated stream response failure")
monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response)
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
await until_event_with_timeout(global_events, InstanceDeleted, timeout=10.0)
events = global_events.collect()
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert (
len(
[
x
for x in events
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 3
)
assert any([isinstance(x.event, InstanceDeleted) for x in events])
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(0.3)
worker.shutdown()
async def test_stream_response_failed_once(
monkeypatch: MonkeyPatch,
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
failed_already = False
original_stream_response = RunnerSupervisor.stream_response
async def mock_stream_response(
self: RunnerSupervisor,
task: Task,
request_started_callback: Callable[..., CoroutineType[Any, Any, None]]
| None = None,
) -> AsyncGenerator[GenerationChunk]:
nonlocal failed_already
if not failed_already:
failed_already = True
raise RuntimeError("Simulated stream response failure")
else:
async for event in original_stream_response(
self, task, request_started_callback
):
yield event
return
monkeypatch.setattr(RunnerSupervisor, "stream_response", mock_stream_response)
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
await until_event_with_timeout(
global_events,
ChunkGenerated,
1,
condition=lambda x: isinstance(x.chunk, TokenChunk)
and x.chunk.finish_reason is not None,
timeout=30.0,
)
# TODO: The ideal with this test is if we had some tooling to scroll through the state, and say
# 'assert that there was a time that the error_type, error_message was not none and the failure count was nonzero'
# as we reset the failures back to zero when we have a successful inference.
assert len(worker.assigned_runners[RUNNER_1_ID].failures) == 0
assert worker.state.tasks[TASK_1_ID].error_type is None
assert worker.state.tasks[TASK_1_ID].error_message is None
events = global_events.collect()
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 1
)
assert (
len(
[
x
for x in events
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 1
)
response_string = ""
events = global_events.collect()
seen_task_started, seen_task_finished = False, False
for wrapped_event in events:
event = wrapped_event.event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.Running:
seen_task_started = True
if event.task_status == TaskStatus.Complete:
seen_task_finished = True
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
response_string += event.chunk.text
assert "queen" in response_string.lower()
assert seen_task_started
assert seen_task_finished
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(0.3)
worker.shutdown()
async def test_stream_response_timeout(
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT"
await global_events.append_events(
[
InstanceCreated(instance=instance_value),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
await until_event_with_timeout(
global_events, TaskFailed, multiplicity=3, timeout=30.0
)
events = global_events.collect()
print(events)
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert (
len(
[
x
for x in events
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 3
)
assert (
len(
[
x
for x in events
if isinstance(x.event, TaskFailed)
and "timeouterror" in x.event.error_type.lower()
]
)
== 3
)
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance_value.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(0.3)
worker.shutdown()

View File

@@ -1,71 +0,0 @@
from typing import Callable
from anyio import create_task_group
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from exo.shared.types.common import NodeId
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from exo.shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
)
from exo.shared.types.worker.common import InstanceId, RunnerId
from exo.shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from exo.shared.types.worker.runners import (
FailedRunnerStatus,
)
from exo.worker.main import Worker
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout
async def test_runner_spinup_timeout(
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].should_timeout = 10
await global_events.append_events(
[InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID
)
await until_event_with_timeout(
global_events,
RunnerStatusUpdated,
multiplicity=3,
condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus),
)
# Ensure the correct events have been emitted
events = global_events.collect()
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()

View File

@@ -1,109 +0,0 @@
import asyncio
from typing import Callable
from anyio import create_task_group
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from exo.shared.types.common import NodeId
# TaskStateUpdated and ChunkGenerated are used in test_worker_integration_utils.py
from exo.shared.types.events import (
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
)
from exo.shared.types.worker.common import InstanceId, RunnerId
from exo.shared.types.worker.instances import (
Instance,
InstanceStatus,
)
from exo.shared.types.worker.runners import (
FailedRunnerStatus,
)
from exo.worker.main import Worker
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.worker_management import WorkerMailbox, until_event_with_timeout
async def test_runner_spinup_exception(
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].immediate_exception = True
await global_events.append_events(
[InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID
)
await asyncio.sleep(10.0)
# Ensure the correct events have been emitted
events = global_events.collect()
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()
async def test_runner_spinup_timeout(
instance: Callable[[InstanceId, NodeId, RunnerId], Instance],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].should_timeout = 10
await global_events.append_events(
[InstanceCreated(instance=instance_value)], origin=MASTER_NODE_ID
)
await until_event_with_timeout(
global_events,
RunnerStatusUpdated,
multiplicity=3,
condition=lambda x: isinstance(x.runner_status, FailedRunnerStatus),
)
# Ensure the correct events have been emitted
events = global_events.collect()
assert (
len(
[
x
for x in events
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()

View File

@@ -1,525 +0,0 @@
import asyncio
import os
import time
from typing import Callable
import pytest
from anyio import create_task_group
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.common import Host
from exo.shared.types.events import (
ChunkGenerated,
InstanceCreated,
InstanceDeleted,
RunnerStatusUpdated,
TaskCreated,
)
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import (
ChatCompletionTask,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import (
Instance,
InstanceStatus,
ShardAssignments,
)
from exo.shared.types.worker.runners import LoadedRunnerStatus
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.main import Worker
from exo.worker.tests.constants import (
COMMAND_1_ID,
COMMAND_2_ID,
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
TASK_2_ID,
)
from exo.worker.tests.worker_management import (
WorkerMailbox,
read_streaming_response,
until_event_with_timeout,
)
MODEL_ID = "mlx-community/Llama-3.3-70B-Instruct-4bit"
SKIP = True
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta(MODEL_ID)
def _get_model_size_gb(path: str) -> float:
"""Calculate total size of directory recursively in GB."""
total_size = 0
for dirpath, _, filenames in os.walk(path):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
if os.path.isfile(filepath):
total_size += os.path.getsize(filepath)
return total_size / (1024**3) # Convert bytes to GB
skip = SKIP or not (
os.path.exists(
os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/")
)
and _get_model_size_gb(
os.path.expanduser("~/.exo/models/mlx-community--Llama-3.3-70B-Instruct-4bit/")
)
> 30
)
@pytest.mark.skipif(
skip,
reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded",
)
async def test_ttft(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
worker_and_mailbox: tuple[Worker, WorkerMailbox],
):
from loguru import logger
worker, global_events = worker_and_mailbox
async with create_task_group() as tg:
tg.start_soon(worker.run)
## Instance
model_id = ModelId(MODEL_ID)
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={RUNNER_1_ID: pipeline_shard_meta(1, 0)},
node_to_runner={NODE_A: RUNNER_1_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(1),
)
# Create instance first
await global_events.append_events(
[InstanceCreated(instance=instance)], origin=MASTER_NODE_ID
)
await until_event_with_timeout(
global_events,
event_type=RunnerStatusUpdated,
condition=lambda x: isinstance(x.runner_status, LoadedRunnerStatus),
)
logger.info("model loaded.")
# First inference
task1_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user", content="Please write a haiku about a flower."
)
],
stream=True,
max_tokens=100,
)
task1 = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=task1_params,
)
print("Starting first inference...")
# Clean out the current global events
_ = global_events.collect()
task_created_time_1 = time.time()
await global_events.append_events(
[TaskCreated(task_id=task1.task_id, task=task1)], origin=MASTER_NODE_ID
)
# Wait for first chunk to measure time to first token
first_chunk_seen_1 = False
time_to_first_token_1: None | float = None
while not first_chunk_seen_1:
event = (await global_events.receive()).event
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
first_chunk_time_1 = time.time()
time_to_first_token_1 = first_chunk_time_1 - task_created_time_1
first_chunk_seen_1 = True
break
(
_,
seen_task_finished_1,
response_string_1,
token_count_1,
) = await read_streaming_response(global_events)
total_time_1 = time.time() - task_created_time_1
assert seen_task_finished_1
# Wait for first task to complete
await asyncio.sleep(5.0)
# Second inference
task2_params = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user", content="Write me a haiku about a robot."
)
],
stream=True,
max_tokens=150,
)
task2 = ChatCompletionTask(
task_id=TASK_2_ID,
command_id=COMMAND_2_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=task2_params,
)
print("Starting second inference...")
# Clean out the current global events
# Record the current event index before creating the second task
_ = global_events.collect()
task_created_time_2 = time.time()
await global_events.append_events(
[TaskCreated(task_id=task2.task_id, task=task2)], origin=MASTER_NODE_ID
)
# Wait for first chunk of second task to measure time to first token
first_chunk_seen_2 = False
time_to_first_token_2: float | None = None
while not first_chunk_seen_2:
event = (await global_events.receive()).event
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
first_chunk_time_2 = time.time()
time_to_first_token_2 = first_chunk_time_2 - task_created_time_2
first_chunk_seen_2 = True
break
(
_,
seen_task_finished_2,
response_string_2,
token_count_2,
) = await read_streaming_response(global_events, filter_task=TASK_2_ID)
total_time_2 = time.time() - task_created_time_2
assert seen_task_finished_2
assert time_to_first_token_1
assert time_to_first_token_2
# Calculate TPS metrics
# Prompt is approximately 45 tokens according to user
prompt_tokens = 45
# Prefill TPS = prompt tokens / time to first token
prefill_tps_1 = (
prompt_tokens / time_to_first_token_1 if time_to_first_token_1 > 0 else 0
)
prefill_tps_2 = (
prompt_tokens / time_to_first_token_2 if time_to_first_token_2 > 0 else 0
)
# Generation TPS = generated tokens / generation time
# Generation time = total time - time to first token
generation_time_1 = total_time_1 - time_to_first_token_1
generation_time_2 = total_time_2 - time_to_first_token_2
generation_tps_1 = (
token_count_1 / generation_time_1 if generation_time_1 > 0 else 0
)
generation_tps_2 = (
token_count_2 / generation_time_2 if generation_time_2 > 0 else 0
)
# Display time to first token profiling results
print("\n=== Time to First Token Profiling ===")
print(f"First inference ('{task1.task_params.messages[0].content}'):")
print(f" Time to first token: {time_to_first_token_1:.3f}s")
print(f" Total completion time: {total_time_1:.3f}s")
print(f" Tokens generated: {token_count_1}")
print(f" Response length: {len(response_string_1)} chars")
print(
f" Prefill TPS: {prefill_tps_1:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_1:.3f}s)"
)
print(
f" Generation TPS: {generation_tps_1:.1f} tokens/sec ({token_count_1} tokens / {generation_time_1:.3f}s)"
)
print(f"\nSecond inference ('{task2.task_params.messages[0].content}'):")
print(f" Time to first token: {time_to_first_token_2:.3f}s")
print(f" Total completion time: {total_time_2:.3f}s")
print(f" Tokens generated: {token_count_2}")
print(f" Response length: {len(response_string_2)} chars")
print(
f" Prefill TPS: {prefill_tps_2:.1f} tokens/sec ({prompt_tokens} prompt tokens / {time_to_first_token_2:.3f}s)"
)
print(
f" Generation TPS: {generation_tps_2:.1f} tokens/sec ({token_count_2} tokens / {generation_time_2:.3f}s)"
)
print("\nComparison:")
print(
f" Second inference time to first token: {time_to_first_token_2 / time_to_first_token_1:.2f}x the first"
)
print(
f" Second inference prefill TPS: {prefill_tps_2 / prefill_tps_1:.2f}x the first"
)
print(
f" Second inference generation TPS: {generation_tps_2 / generation_tps_1:.2f}x the first"
)
# Basic assertions to ensure responses make sense
assert len(response_string_1) > 0
assert len(response_string_2) > 0
assert time_to_first_token_1 and time_to_first_token_1 > 0
assert time_to_first_token_2 and time_to_first_token_2 > 0
# Cleanup
_ = global_events.collect()
await asyncio.sleep(1.0)
events = global_events.collect()
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(2.0)
worker.shutdown()
@pytest.mark.skipif(
skip,
reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded",
)
async def test_2_runner_inference(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
):
worker1, worker2, global_events = two_workers_with_shared_mailbox
async with create_task_group() as tg:
tg.start_soon(worker1.run)
tg.start_soon(worker2.run)
## Instance
model_id = ModelId(MODEL_ID)
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1),
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[
0
].content = "Can you explain to me how a bubble sort works, speaking as if you are a fairy."
task.task_params.max_tokens = 1000
await global_events.append_events(
[
InstanceCreated(instance=instance),
TaskCreated(task_id=task.task_id, task=task),
],
origin=MASTER_NODE_ID,
)
(
seen_task_started,
seen_task_finished,
response_string,
_,
) = await read_streaming_response(global_events)
assert seen_task_started
assert seen_task_finished
assert "swap" in response_string.lower()
_ = global_events.collect()
await asyncio.sleep(1.0)
events = global_events.collect()
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(2.0)
worker1.shutdown()
worker2.shutdown()
@pytest.mark.skipif(
skip,
reason="This test only runs when model mlx-community/Llama-3.3-70B-Instruct-4bit is downloaded",
)
async def test_parallel_inference(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
):
worker1, worker2, global_events = two_workers_with_shared_mailbox
async with create_task_group() as tg:
tg.start_soon(worker1.run)
tg.start_soon(worker2.run)
## Instance
model_id = ModelId(MODEL_ID)
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1),
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
completion_create_params_1 = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user", content='Tell me a haiku that uses the word "pond".'
)
],
stream=True,
max_tokens=1000,
)
task1 = ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=completion_create_params_1,
)
completion_create_params_2 = ChatCompletionTaskParams(
model="gpt-4",
messages=[
ChatCompletionMessage(
role="user", content='Tell me a haiku that uses the word "tree".'
)
],
stream=True,
max_tokens=1000,
)
task2 = ChatCompletionTask(
task_id=TASK_2_ID,
command_id=COMMAND_2_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=completion_create_params_2,
)
await global_events.append_events(
[
InstanceCreated(instance=instance),
TaskCreated(task_id=task1.task_id, task=task1),
TaskCreated(task_id=task2.task_id, task=task2),
],
origin=MASTER_NODE_ID,
)
(
seen_task_started_1,
seen_task_finished_1,
response_string_1,
_,
) = await read_streaming_response(global_events)
# Read whichever task has not completed yet as the second stream.
incomplete_task = (
TASK_2_ID
if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.Complete
else TASK_1_ID
)
(
seen_task_started_2,
seen_task_finished_2,
response_string_2,
_,
) = await read_streaming_response(global_events, filter_task=incomplete_task)
assert seen_task_started_1
assert seen_task_finished_1
assert seen_task_started_2
assert seen_task_finished_2
print(response_string_1)
print(response_string_2)
assert ("pond" in response_string_1.lower()) ^ (
"pond" in response_string_2.lower()
), "'pond' must appear in exactly one response"
assert ("tree" in response_string_1.lower()) ^ (
"tree" in response_string_2.lower()
), "'tree' must appear in exactly one response"
_ = global_events.collect()
await asyncio.sleep(1.0)
events = global_events.collect()
assert len(events) == 0
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(2.0)
worker1.shutdown()
worker2.shutdown()

View File

@@ -1,550 +0,0 @@
import pytest
from exo.worker.common import AssignedRunner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
TaskStatus,
)
from exo.shared.types.worker.common import WorkerStatus
from exo.shared.types.worker.downloads import (
DownloadPending,
)
from exo.shared.types.worker.instances import InstanceStatus
from exo.shared.types.worker.ops import (
AssignRunnerOp,
ExecuteTaskOp,
RunnerDownOp,
RunnerUpOp,
UnassignRunnerOp,
)
from exo.shared.types.worker.runners import (
DownloadingRunnerStatus,
FailedRunnerStatus,
InactiveRunnerStatus,
LoadedRunnerStatus,
RunningRunnerStatus,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.main import Worker
from exo.worker.plan import plan
from exo.worker.tests.constants import (
COMMAND_1_ID,
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
TASK_1_ID,
)
from exo.worker.tests.test_plan.test_worker_plan_utils import (
InProcessRunner,
PlanTestCase,
make_downloading_status,
make_model_meta,
make_state,
make_test_case,
)
"""
The idea with these tests is to declaratively define the input and expected output of the worker.plan function.
We initialize a Worker with InProcessRunners, construct a State, pass it to Worker.plan,
and then check which operation Worker.plan returns.
Note that the 'self' node is always NODE_A, which is why some failure scenarios appear with the runner roles swapped around.
"""
def _get_test_cases() -> list[PlanTestCase]:
# The `model_path` for `RUNNER_1_ID` must exist for the `DownloadOp` test case to pass validation.
model_a_meta = make_model_meta(MODEL_A_ID)
return [
PlanTestCase(
description="no runners -> no-op",
in_process_runners=[],
state=State(
node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={}
),
expected_op=None,
),
# Both 'assigned' and 'downloading' should be blocking ops - so if we are in either of these we should unassign to retry.
# This needs to change when we move to an async worker
make_test_case(
description="runner state assigned, runner is assigned and downloading -> unassign",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": make_downloading_status(NODE_A),
"downloaded": False,
}
],
instance_status=InstanceStatus.Inactive,
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="ready runner, model present -> no-op",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": InactiveRunnerStatus(),
"downloaded": True,
}
],
instance_status=InstanceStatus.Inactive,
expected_op=None,
),
PlanTestCase(
description="runner assigned and not in state -> AssignRunnerOp",
in_process_runners=[],
state=make_state(
runner_specs_per_instance={
INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())]
},
model_id=MODEL_A_ID,
instance_status=InstanceStatus.Active, # Either active or inactive should yield the same.
),
expected_op=AssignRunnerOp(
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
shard_metadata=PipelineShardMetadata(
device_rank=0,
world_size=1,
model_meta=model_a_meta,
start_layer=0,
end_layer=1,
n_layers=1,
),
hosts=[],
),
),
PlanTestCase(
description="runner assigned but no longer in state -> UnassignRunnerOp",
in_process_runners=[
InProcessRunner(
runner_id=RUNNER_1_ID,
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
status=InactiveRunnerStatus(),
downloaded=False,
)
],
state=State(
node_status={NODE_A: WorkerStatus.Idle}, instances={}, runners={}
),
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="ready runner (and state up) -> expect RunnerUpOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": InactiveRunnerStatus(),
"downloaded": True,
}
],
instance_status=InstanceStatus.Active,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="1 ready, 1 downloading (and state up) -> no-op",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": InactiveRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": DownloadingRunnerStatus(
download_progress=DownloadPending(node_id=NODE_A)
),
"downloaded": False,
},
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
description="2 ready runners (and state up) -> expect RunnerUpOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": InactiveRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": InactiveRunnerStatus(),
"downloaded": True,
},
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="loaded runner (and state down) -> expect RunnerDownOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
}
],
instance_status=InstanceStatus.Inactive,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="failed runner (and state down) -> expect RunnerDownOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": FailedRunnerStatus(),
"downloaded": True,
}
],
instance_status=InstanceStatus.Inactive,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="loaded runner, model present, task pending -> expect ExecuteTaskOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
}
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(role="user", content="Hello, world!")
],
),
),
),
),
# We should only run rank 0 once all other ranks are running.
make_test_case(
description="two loaded runners & task, i'm rank 0 -> no-op",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
description="two loaded runners & task, i'm rank 1 -> expect ExecuteTaskOp on rank 1",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 1,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(role="user", content="Hello, world!")
],
),
task_status=TaskStatus.Pending,
),
),
),
make_test_case(
description="rank 1 loaded, rank 0 ready, i'm rank 0 -> expect ExecuteTaskOp on rank 0",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": RunningRunnerStatus(),
"downloaded": True,
},
],
tasks=[
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(role="user", content="Hello, world!")
],
),
task_status=TaskStatus.Pending,
),
),
),
make_test_case(
description="this runner failed (1 node) -> RunnerDownOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": FailedRunnerStatus(),
"downloaded": True,
}
],
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="other runner failed -> RunnerDownOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": FailedRunnerStatus(),
"downloaded": True,
},
],
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
description="this runner failed (2 nodes) -> no-op",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": FailedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": LoadedRunnerStatus(),
"downloaded": True,
},
],
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
description="this node failed, other node spun down -> RunnerDownOp",
runner_specs=[
{
"runner_id": RUNNER_1_ID,
"node_id": NODE_A,
"device_rank": 0,
"status": FailedRunnerStatus(),
"downloaded": True,
},
{
"runner_id": RUNNER_2_ID,
"node_id": NODE_B,
"device_rank": 1,
"status": InactiveRunnerStatus(),
"downloaded": True,
},
],
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
]
# ---------------------------------------------------------------------------
# Parametrised test
# ---------------------------------------------------------------------------
# Pre-compute readable identifiers for each case to avoid lambda typing issues.
@pytest.mark.parametrize(
"case",
# Test cases are generated at collection time; test_worker_plan regenerates them per run.
[pytest.param(c, id=c.id()) for c in _get_test_cases()],
)
def test_worker_plan(case: PlanTestCase, worker_void_mailbox: Worker) -> None:
"""Exercise Worker.plan across declarative scenarios."""
print(f"----- case: {case.description}")
# Regenerate the test cases so each parametrized run gets fresh, unshared objects
test_cases = {c.description: c for c in _get_test_cases()}
case = test_cases[case.description]
worker = worker_void_mailbox
runner_config: InProcessRunner
for runner_config in case.in_process_runners:
if len(case.state.instances) == 1:
instance_id = next(iter(case.state.instances))
shard_assignments = case.state.instances[instance_id].shard_assignments
shard_metadata = shard_assignments.runner_to_shard[runner_config.runner_id]
# Only add this runner if it belongs to our node
runner_node = None
for node, runner in shard_assignments.node_to_runner.items():
if runner == runner_config.runner_id:
runner_node = node
break
if runner_node != worker.node_id:
# This runner belongs to a different node, skip it
continue
elif len(case.state.instances) == 0:
shard_metadata = PipelineShardMetadata(
device_rank=runner_config.device_rank,
world_size=1,
model_meta=make_model_meta(runner_config.model_id),
start_layer=0,
end_layer=1,
n_layers=1,
)
else:
raise Exception(
"test_worker_plan not currently designed to have more than 1 instance."
)
assigned_runner = AssignedRunner(
runner_id=runner_config.runner_id,
instance_id=runner_config.instance_id,
shard_metadata=shard_metadata,
hosts=[],
status=runner_config.status,
runner=None,
)
worker.assigned_runners[runner_config.runner_id] = assigned_runner
op = plan(
worker.assigned_runners,
NODE_A,
case.state.instances,
case.state.runners,
case.state.tasks,
)
assert op == case.expected_op

View File

@@ -1,292 +0,0 @@
from dataclasses import dataclass
from typing import NotRequired, TypedDict
from typing_extensions import Literal
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId, RunnerId, WorkerStatus
from exo.shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
from exo.shared.types.worker.instances import Instance, InstanceStatus
from exo.shared.types.worker.ops import RunnerOp
from exo.shared.types.worker.runners import (
DownloadingRunnerStatus,
RunnerStatus,
RunningRunnerStatus,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.tests.constants import COMMAND_1_ID, INSTANCE_1_ID, MODEL_A_ID
class RunnerSpecDict(TypedDict):
"""Type definition for runner specification dictionaries."""
runner_id: RunnerId
node_id: NodeId
device_rank: int
status: RunnerStatus
downloaded: NotRequired[bool] # defaults to True if not provided
class MessageDict(TypedDict):
"""Type definition for message dictionaries."""
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: NotRequired[str | None]
name: NotRequired[str | None]
tool_calls: NotRequired[list[dict[str, str]] | None]
tool_call_id: NotRequired[str | None]
function_call: NotRequired[dict[str, str] | None]
class TaskSpecDict(TypedDict):
"""Type definition for task specification dictionaries."""
task_id: TaskId
instance_id: NotRequired[
InstanceId
] # defaults to function parameter if not provided
command_id: NotRequired[CommandId] # defaults to COMMAND_1_ID if not provided
status: NotRequired[TaskStatus] # defaults to TaskStatus.Pending if not provided
model: NotRequired[str] # defaults to model_id if not provided
messages: NotRequired[
list[MessageDict]
] # defaults to [{'role': 'user', 'content': 'Hello, world!'}] if not provided
@dataclass(slots=True, frozen=True)
class InProcessRunner:
"""Minimal description of a runner's in-process state."""
runner_id: RunnerId
instance_id: InstanceId
model_id: ModelId
status: RunnerStatus
downloaded: bool
device_rank: int = 0
@dataclass(slots=True, frozen=True)
class PlanTestCase:
"""Table-driven description of an entire planning scenario."""
description: str
state: State
in_process_runners: list[InProcessRunner]
expected_op: RunnerOp | None
def id(self) -> str: # noqa: D401
return self.description.replace(" ", "_")
def make_shard_metadata(
device_rank: int, world_size: int, model_id: ModelId = MODEL_A_ID
) -> PipelineShardMetadata:
"""Create PipelineShardMetadata with proper layer assignments based on device_rank and world_size."""
total_layers = world_size # For simplicity in tests, total_layers = world_size
if world_size == 1:
start_layer = 0
end_layer = 1
n_layers = 1
else:
# For multi-device setup, each device gets one layer
start_layer = device_rank
end_layer = device_rank + 1
n_layers = total_layers
return PipelineShardMetadata(
device_rank=device_rank,
world_size=world_size,
model_meta=make_model_meta(model_id),
start_layer=start_layer,
end_layer=end_layer,
n_layers=n_layers,
)
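# For example (illustrative): make_shard_metadata(device_rank=1, world_size=2)
# yields start_layer=1, end_layer=2, n_layers=2, i.e. each device owns exactly one layer.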
def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
"""Factory for a *Downloading* status with placeholder progress."""
return DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(1),
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
completed_files=0,
total_files=0,
speed=0,
eta_ms=0,
files={},
),
)
)
def make_model_meta(model_id: str) -> ModelMetadata:
model_card: ModelCard
for card in MODEL_CARDS.values():
if card.model_id == model_id:
model_card = card
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.model_id,
storage_size=Memory.from_kb(10**6),
n_layers=16,
)
raise Exception(f"Unknown model_id passed: {model_id}")
## Alternatively, if we are ok for this method to be async:
# await _get_model_meta(model_id)
def make_instance(
instance_id: InstanceId,
runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.Active,
) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, WorkerStatus]]:
"""Creates an instance with one or more runners."""
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
world_size = len(runner_specs)
for runner_id, node_id, device_rank, _ in runner_specs:
shard_metadata = make_shard_metadata(device_rank, world_size, model_id)
runner_to_shard[runner_id] = shard_metadata
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
instance = Instance(
instance_id=instance_id,
instance_type=instance_status,
shard_assignments=shard_assignments,
hosts=[],
)
# Currently nodes are only ever idle - as if they were running we would be blocking - so we wouldn't be running plan()
# node_statuses = {node_id: WorkerStatus.Idle for _, node_id, _, _ in runner_specs}
node_statuses: dict[NodeId, WorkerStatus] = {}
for _runner_id, node_id, _, status in runner_specs:
if isinstance(status, RunningRunnerStatus):
node_statuses[node_id] = WorkerStatus.Running
else:
node_statuses[node_id] = WorkerStatus.Idle
runner_statuses = {runner_id: status for runner_id, _, _, status in runner_specs}
return instance, runner_statuses, node_statuses
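# Usage sketch (illustrative; the runner/node ids and status come from exo.worker.tests.constants
# and the runner status types imported above):
# instance, runner_statuses, node_statuses = make_instance(
# INSTANCE_1_ID, [(RUNNER_1_ID, NODE_A, 0, RunningRunnerStatus())]
# )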
def make_state(
runner_specs_per_instance: dict[
InstanceId, list[tuple[RunnerId, NodeId, int, RunnerStatus]]
],
tasks: dict[TaskId, ChatCompletionTask] | None = None,
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.Active,
) -> State:
"""Builds a full State from runner specs per instance, tasks, and defaults."""
if tasks is None:
tasks = {}
instances: dict[InstanceId, Instance] = {}
all_runner_statuses: dict[RunnerId, RunnerStatus] = {}
all_node_statuses: dict[NodeId, WorkerStatus] = {}
for inst_id, specs in runner_specs_per_instance.items():
# Build per-instance data using make_instance
instance, runner_statuses, node_statuses = make_instance(
instance_id=inst_id,
runner_specs=specs,
model_id=model_id,
instance_status=instance_status,
)
instances[inst_id] = instance
all_runner_statuses.update(runner_statuses)
all_node_statuses.update(node_statuses)
return State(
node_status=all_node_statuses,
instances=instances,
runners=all_runner_statuses,
tasks=tasks,
)
def make_test_case(
description: str,
runner_specs: list[RunnerSpecDict],
tasks: list[TaskSpecDict] | None = None,
expected_op: RunnerOp | None = None,
instance_id: InstanceId = INSTANCE_1_ID,
instance_status: InstanceStatus = InstanceStatus.Active,
model_id: ModelId = MODEL_A_ID,
command_id: CommandId = COMMAND_1_ID, # Default for tasks
) -> PlanTestCase:
"""Builds a PlanTestCase from high-level specs."""
if tasks is None:
tasks = []
# Convert runner_specs to tuple format for make_instance
specs_tuple = [
(r["runner_id"], r["node_id"], r["device_rank"], r["status"])
for r in runner_specs
]
# Build state using make_state (wrap single instance)
state_tasks: dict[TaskId, ChatCompletionTask] = {}
for t in tasks:
task = ChatCompletionTask(
instance_id=instance_id,
task_id=t["task_id"],
command_id=t.get("command_id", command_id),
task_status=t.get("status", TaskStatus.Pending),
task_params=ChatCompletionTaskParams(
model=t.get("model", str(model_id)),
messages=[
ChatCompletionMessage(**m)
for m in t.get(
"messages", [{"role": "user", "content": "Hello, world!"}]
)
],
),
)
state_tasks[t["task_id"]] = task
state = make_state(
runner_specs_per_instance={instance_id: specs_tuple},
tasks=state_tasks,
model_id=model_id,
instance_status=instance_status,
)
# Build in_process_runners with downloaded (default True if missing)
in_process_runners = [
InProcessRunner(
runner_id=r["runner_id"],
instance_id=instance_id,
model_id=model_id,
status=r["status"],
downloaded=r.get("downloaded", True),
device_rank=r["device_rank"],
)
for r in runner_specs
]
return PlanTestCase(
description=description,
state=state,
in_process_runners=in_process_runners,
expected_op=expected_op,
)

View File

@@ -1,181 +0,0 @@
import asyncio
import os
from typing import Callable
import pytest
from anyio import create_task_group, move_on_after
from exo.shared.types.common import Host
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.models import ModelId
from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from exo.shared.types.worker.runners import FailedRunnerStatus
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.main import Worker
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MASTER_NODE_ID,
NODE_A,
NODE_B,
RUNNER_1_ID,
RUNNER_2_ID,
)
from exo.worker.tests.worker_management import WorkerMailbox
@pytest.fixture
def user_message() -> str:
return "What is the capital of Japan?"
@pytest.mark.skipif(
os.environ.get("DETAILED", "").lower() != "true",
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set",
)
async def check_runner_connection(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
hosts: Callable[[int], list[Host]],
two_workers_with_shared_mailbox: tuple[Worker, Worker, WorkerMailbox],
) -> bool:
async def wait_for_runner_supervisor(
worker: Worker, timeout: float = 5.0
) -> RunnerSupervisor | None:
with move_on_after(timeout):
while True:
assigned_runners = list(worker.assigned_runners.values())
if assigned_runners:
runner = assigned_runners[0].runner
if isinstance(runner, RunnerSupervisor):
print("breaking because success")
return runner
if isinstance(assigned_runners[0].status, FailedRunnerStatus):
print("breaking because failed")
return runner
await asyncio.sleep(0.001)
worker1, worker2, global_events = two_workers_with_shared_mailbox
# Track all tasks and workers for cleanup
async with create_task_group() as tg:
tg.start_soon(worker1.run)
tg.start_soon(worker2.run)
model_id = ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit")
shard_assignments = ShardAssignments(
model_id=model_id,
runner_to_shard={
RUNNER_1_ID: pipeline_shard_meta(2, 0),
RUNNER_2_ID: pipeline_shard_meta(2, 1),
},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
)
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
await global_events.append_events(
[
InstanceCreated(instance=instance),
],
origin=MASTER_NODE_ID,
)
runner_supervisor = await wait_for_runner_supervisor(worker1, timeout=6.0)
ret = (
runner_supervisor is not None
and runner_supervisor.runner_process.is_alive()
)
await global_events.append_events(
[
InstanceDeleted(
instance_id=instance.instance_id,
),
],
origin=MASTER_NODE_ID,
)
await asyncio.sleep(0.5)
worker1.shutdown()
worker2.shutdown()
tg.cancel_scope.cancel()
return ret
# should be unreachable
raise AssertionError("unreachable")
# Check Running status
# # not now.
# def test_runner_connection_stress(
# pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
# hosts: Callable[[int], list[Host]],
# chat_completion_task: Callable[[InstanceId, str], Task],
# ) -> None:
# total_runs = 100
# successes = 0
# for _ in range(total_runs):
# # Create a fresh event loop for each iteration
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(check_runner_connection(
# pipeline_shard_meta=pipeline_shard_meta,
# hosts=hosts,
# chat_completion_task=chat_completion_task,
# ))
# if result:
# successes += 1
# finally:
# # Cancel all running tasks
# pending = asyncio.all_tasks(loop)
# for task in pending:
# task.cancel()
# # Run the event loop briefly to allow cancellation to complete
# loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# # Close the event loop
# loop.close()
# print(f"Runner connection successes: {successes} / {total_runs}")

View File

@@ -1,43 +0,0 @@
from typing import Callable
from pydantic import BaseModel, TypeAdapter
from exo.shared.types.common import Host
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.commands_runner import (
ChatTaskMessage,
RunnerMessage,
SetupMessage,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.shards import PipelineShardMetadata
def assert_equal_serdes[T: BaseModel](obj: T, typeadapter: TypeAdapter[T]):
encoded: bytes = obj.model_dump_json().encode("utf-8") + b"\n"
decoded: T = typeadapter.validate_json(encoded)
assert decoded == obj, (
f"Decoded: {decoded} != \nOriginal: {obj}. \n binary encoded: {encoded}"
)
def test_supervisor_setup_message_serdes(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
setup_message = SetupMessage(
model_shard_meta=pipeline_shard_meta(1, 0),
hosts=hosts(1),
)
assert_equal_serdes(setup_message, TypeAdapter(RunnerMessage))
def test_supervisor_task_message_serdes(
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
task = chat_completion_task(InstanceId(), TaskId())
task_message = ChatTaskMessage(
task_data=task.task_params,
)
assert_equal_serdes(task_message, TypeAdapter(RunnerMessage))

View File

@@ -1,50 +0,0 @@
## Tests for worker state handlers
import os
from typing import Callable
import pytest
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.ops import (
RunnerUpOp,
)
from exo.shared.types.worker.runners import FailedRunnerStatus
from exo.worker.main import Worker
from exo.worker.tests.constants import RUNNER_1_ID
# To enable this test, run pytest with: DETAILED=true pytest
@pytest.mark.skipif(
os.environ.get("DETAILED", "").lower() != "true",
reason="This test only runs when ENABLE_SPINUP_TIMEOUT_TEST=true environment variable is set",
)
@pytest.mark.asyncio
async def test_runner_up_op_timeout(
worker_with_assigned_runner: tuple[Worker, Instance],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
monkeypatch: pytest.MonkeyPatch,
):
worker, _ = worker_with_assigned_runner
runner_up_op = RunnerUpOp(runner_id=RUNNER_1_ID)
# With a very short initialize_timeout, _execute_runner_up_op should time out and emit a FailedRunnerStatus event
events: list[Event] = []
async for event in worker._execute_runner_up_op( # type: ignore[misc]
runner_up_op, initialize_timeout=0.2
):
events.append(event)
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(events[-1].runner_status, FailedRunnerStatus)
assert events[-1].runner_status.error_message is not None
assert "timeout" in events[-1].runner_status.error_message.lower()
del worker.assigned_runners[list(worker.assigned_runners.keys())[0]]

View File

@@ -1,163 +0,0 @@
import asyncio
from typing import Callable
import pytest
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.openai_compat import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import Host
from exo.shared.types.tasks import (
Task,
TaskId,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@pytest.fixture
def user_message():
"""Override the default message to ask about France's capital"""
return "What is the capital of France?"
@pytest.fixture
def lorem_ipsum() -> str:
return """
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Phasellus rhoncus felis in velit tempus tristique. Nullam ipsum lectus, tristique a eros quis, ullamcorper accumsan lorem. Aliquam ut auctor elit, finibus porttitor neque. In cursus augue facilisis ante ullamcorper, at sollicitudin quam aliquam. Etiam ac lacinia lacus, et aliquet nunc. Phasellus nisi ex, feugiat quis dolor non, mollis consequat nulla. Suspendisse gravida, sem non lobortis viverra, turpis lacus elementum orci, in tristique augue tortor nec mauris. Curabitur aliquet lorem in rhoncus mollis. Aliquam pulvinar elit odio, ac feugiat magna luctus nec. Pellentesque non risus egestas, pellentesque arcu tincidunt, gravida risus. Etiam ut lorem ac lorem pharetra efficitur. Donec augue arcu, varius nec lorem vitae, suscipit semper tellus. Aliquam dignissim quis augue id fermentum. Proin aliquet pellentesque est, eget tincidunt odio ullamcorper vel. Suspendisse potenti.
Aenean imperdiet justo sit amet erat aliquet tristique. Sed tempus, turpis a cursus lobortis, ante sem imperdiet est, eu dapibus sapien velit eget elit. Donec feugiat sed risus sed scelerisque. Donec posuere tempor orci, sit amet pellentesque est efficitur non. Vivamus sodales pretium purus, sed rutrum enim auctor ut. Cras pharetra vitae libero et hendrerit. Sed nec tempus odio. Proin blandit facilisis scelerisque. Nulla in mattis mi. Etiam bibendum efficitur aliquam. Proin ut risus aliquet, rhoncus lectus non, rhoncus arcu. Nam nibh felis, ultrices a elit sed, ultricies sollicitudin tellus. Interdum et malesuada fames ac ante ipsum primis in faucibus. Maecenas faucibus magna ut purus imperdiet faucibus. Nam fermentum nulla fermentum magna aliquam, vel lacinia neque euismod. Donec tincidunt sed neque non facilisis.
Proin id lorem cursus, vehicula ante non, lacinia metus. Nam egestas dui a iaculis convallis. Ut suscipit justo est, nec pharetra ante accumsan ac. Pellentesque nec nisi ipsum. Duis non arcu neque. Curabitur non luctus purus. Phasellus pulvinar commodo lacus sit amet auctor. Ut ut mattis metus, eu auctor arcu. Etiam a suscipit est. Morbi orci mauris, suscipit tempus fermentum vel, luctus faucibus lectus. Aliquam a euismod arcu. Suspendisse porttitor eget libero vitae laoreet.
Fusce congue lorem mi, a mollis felis efficitur quis. Quisque lobortis scelerisque arcu, a varius sapien. Nulla eget orci non urna imperdiet tincidunt. Nunc mi massa, consectetur id lorem consectetur, molestie dignissim sem. Suspendisse et augue magna. Mauris id tempus velit, cursus suscipit tortor. Duis non mi non nisi fringilla maximus in et erat.
Proin consequat sapien eget tellus aliquam ultrices. Nunc hendrerit semper massa, pulvinar sodales ipsum condimentum eu. Proin vel ligula venenatis, lobortis lectus eu, vehicula justo. Mauris eu arcu at orci vehicula feugiat non eu metus. Duis ut vestibulum quam. Maecenas dolor elit, egestas ut purus sit amet, convallis lobortis massa. Ut volutpat augue ac ante consectetur dignissim. Maecenas vitae felis elementum, semper augue eu, auctor dolor. Ut pulvinar convallis tortor non volutpat. Curabitur vulputate sem sodales sapien pretium ultrices. Sed luctus libero vitae urna eleifend tincidunt. Proin pulvinar imperdiet cursus. Suspendisse ullamcorper laoreet leo dapibus tincidunt. Pellentesque molestie elementum felis.
Integer vitae congue nulla. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Vestibulum elit velit, malesuada quis ipsum et, imperdiet varius velit. Nam tristique viverra maximus. Curabitur eget semper lectus. Sed vitae lorem sit amet mi lacinia posuere ac a risus. Pellentesque et magna nisl. In hac habitasse platea dictumst. Aenean suscipit, nibh vitae sollicitudin commodo, risus mi commodo neque, nec venenatis velit augue sed massa. Nam tempus, arcu id eleifend auctor, est dui viverra odio, vel convallis arcu dolor id quam. Ut malesuada ligula vel interdum eleifend. In posuere ultrices tincidunt. Sed non enim sit amet lectus sagittis mattis eu at sapien. Pellentesque eu urna mollis, vehicula dolor eget, lobortis nisl. Suspendisse ex nisi, iaculis non sapien ac, fringilla rutrum dolor. Quisque pretium mauris nec ante gravida, sed laoreet neque viverra.
Donec mattis orci sit amet tincidunt maximus. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Curabitur tristique venenatis lectus, vel pulvinar sem. Sed vel dolor lacinia, aliquet nisi ac, bibendum libero. Nullam vulputate euismod augue ac imperdiet. Proin at fermentum sapien. Nam et fringilla lorem. Aenean sed lacus sed tellus sodales mattis ut rutrum ex. Nulla ligula diam, interdum quis faucibus sit amet, laoreet vel massa. Fusce mauris massa, tempor quis tempus nec, dictum a ligula. Ut at dapibus sapien. Nullam sem lorem, sollicitudin non dui a, consequat molestie mauris. Quisque sem nulla, vehicula nec vulputate ac, viverra in massa. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur pretium venenatis nisi non bibendum. Nam vitae ligula auctor, rutrum lectus eget, feugiat augue.
Ut nunc risus, vehicula at metus non, consequat suscipit risus. Mauris eget sem in neque tincidunt iaculis. Pellentesque lacus leo, molestie ut pharetra sit amet, porta nec neque. Aliquam eu bibendum odio. Proin tempus bibendum ornare. Morbi non risus vitae ante tempor porta quis sed augue. Nullam hendrerit nulla in eleifend tincidunt. Integer suscipit ligula at nunc blandit vehicula. Nam porttitor leo in turpis suscipit malesuada. Etiam sodales nunc nisi, pharetra malesuada nibh varius in. Cras quis pellentesque augue, vitae convallis velit. In et dui lorem. Integer semper eros eget augue posuere, ac elementum tellus convallis. Praesent blandit tempus ultrices. Suspendisse nec dui vitae neque varius eleifend. Sed pretium metus leo, id viverra tellus scelerisque in.
Aenean sodales urna vitae lobortis cursus. Sed vitae pellentesque erat, fermentum pellentesque urna. Suspendisse potenti. Sed porttitor placerat turpis non vestibulum. Duis in nisi non purus venenatis tempus non eu nisi. Sed bibendum sapien vitae ultricies condimentum. Integer vel mattis lectus, consequat congue ex. Cras convallis odio volutpat nulla vehicula efficitur. Pellentesque eget justo neque. Morbi mattis vitae magna et suscipit. Etiam orci sapien, tincidunt non tellus eget, laoreet vestibulum massa. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris nec nisi enim. Donec risus odio, lobortis in odio malesuada, laoreet rutrum urna. Nunc sit amet euismod quam.
Fusce rhoncus ullamcorper nunc, ut pellentesque nisi dictum sed. Fusce sem mi, bibendum ut dictum at, porta in libero. Pellentesque placerat mollis sapien, sed eleifend lorem consequat in. Phasellus vel tempor ligula. Pellentesque tincidunt suscipit tortor vel blandit. Maecenas purus mi, mattis ac aliquam vel, rutrum eu nulla. Proin rhoncus nec sem a congue. Pellentesque sit amet sapien quam. Sed hendrerit neque id venenatis dignissim.
Vestibulum laoreet eu felis nec aliquam. Praesent gravida ornare odio nec porttitor. Donec ut tellus eros. Proin fringilla urna augue, vitae ornare leo varius non. Curabitur consectetur, purus in iaculis finibus, lectus lacus porttitor dolor, nec eleifend tellus massa eget tellus. Mauris sit amet convallis risus, a fermentum lorem. Suspendisse potenti. Curabitur vulputate finibus maximus. Interdum et malesuada fames ac ante ipsum primis in faucibus. In vel erat pellentesque, rhoncus magna vel, scelerisque mauris.
Nulla facilisi. Morbi mattis felis nec accumsan varius. Vestibulum in sodales arcu. Vivamus egestas, ante nec dapibus vestibulum, tellus ipsum rhoncus mi, at fermentum sapien justo nec turpis. Quisque rhoncus, urna sit amet imperdiet cursus, tortor lacus ultricies sapien, eu bibendum ligula enim id mi. Sed sem leo, pharetra in pulvinar sed, faucibus sed dui. Morbi tempus erat nec neque placerat tincidunt.
Quisque ut lorem sodales magna faucibus mattis. Aenean dui neque, gravida ut fringilla non, fermentum sit amet dolor. Mauris a sapien lacinia, elementum dolor in, sagittis metus. Donec viverra magna non lorem rutrum, at eleifend lacus volutpat. Nunc sit amet dolor tempor, blandit sapien a, consectetur magna. Suspendisse maximus nunc nec imperdiet aliquet. Nunc aliquam interdum purus quis pretium. Mauris molestie feugiat pellentesque. Nunc maximus, est sed consequat malesuada, risus turpis consequat velit, ac feugiat nunc magna vitae ligula. Vestibulum tincidunt massa ante, vitae pellentesque tortor rutrum sed. Aliquam vel est libero. Suspendisse et convallis orci. Cras sed lorem consectetur, blandit massa sit amet, semper neque. Vestibulum et mi euismod, imperdiet justo non, facilisis libero.
Sed at lacus ac tortor dictum tempus. Integer commodo purus lacus, ut pretium est tempor ac. Ut vulputate nulla magna, ac facilisis velit commodo in. Interdum et malesuada fames ac ante ipsum primis in faucibus. Donec pellentesque congue nibh nec eleifend. Ut ante turpis, sodales sed aliquet quis, tempus eu dui. Proin et eros non risus porttitor pharetra.
Mauris a urna id justo gravida ultrices. Mauris commodo sed ipsum a dictum. In posuere luctus scelerisque. Morbi sit amet gravida ipsum. Quisque vel dui sit amet ex lobortis eleifend non vel neque. Fusce sit amet imperdiet felis, eu tempor diam. Pellentesque sit amet turpis in libero tristique posuere. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris quis est suscipit, tristique odio elementum, molestie nibh. Maecenas ex dui, pulvinar quis pellentesque sed, imperdiet nec mauris. Pellentesque ultrices at mauris eget fringilla. Donec bibendum rhoncus felis, ut pretium nulla eleifend commodo.
Ut euismod erat accumsan tincidunt sagittis. Proin eget massa ex. Suspendisse at faucibus enim, vitae posuere mi. Cras nec ex finibus, porttitor purus quis, efficitur libero. Nulla sagittis ornare iaculis. Donec venenatis dui ut libero aliquam lobortis. Vestibulum imperdiet lorem urna, eget gravida orci sollicitudin ut. Quisque ultrices tortor at quam laoreet aliquet. Pellentesque tincidunt consequat pharetra. Cras a lacinia erat. Mauris sed neque lobortis ipsum facilisis hendrerit.
Cras at orci odio. Curabitur eros metus, consequat non placerat et, tincidunt at turpis. Morbi quis viverra metus. Vestibulum molestie, ex at suscipit finibus, ex magna pellentesque nisi, eu ullamcorper nisl sapien eu quam. Phasellus volutpat lacinia enim, nec fermentum augue tincidunt ut. Duis rutrum purus eu nulla elementum, a faucibus odio fringilla. Sed cursus risus neque, dictum luctus tortor tempus eu.
Mauris non arcu eu nunc faucibus tincidunt id quis dolor. Quisque ac fringilla libero. Sed non ligula ut nunc auctor consequat vitae eget metus. Ut suscipit leo quam, vitae ultrices urna feugiat eu. Vestibulum volutpat nisl quis nunc pretium, vel viverra orci fringilla. Proin erat nibh, laoreet nec nisi sit amet, volutpat efficitur nunc. Cras id tortor quis lectus imperdiet rutrum non id purus. Proin efficitur ligula non dapibus consectetur. Nam quis quam eget dui facilisis scelerisque. Praesent non bibendum risus. Etiam imperdiet nisi id consectetur porta. In pretium nulla ut leo ultricies rhoncus.
Curabitur non vehicula purus. Cras et justo risus. Duis et rutrum urna. Aliquam condimentum purus nec ante dignissim rhoncus. Vestibulum commodo pharetra eros, ac euismod orci rutrum vel. Integer sed cursus erat, euismod accumsan libero. Nullam ut odio sit amet nibh tempor congue. Pellentesque porttitor aliquam ipsum, sit amet facilisis quam fringilla ac. Aliquam scelerisque tempor nisl in tempor. Sed vestibulum, tellus sit amet mattis pellentesque, eros diam convallis felis, id pellentesque massa leo quis dolor. Integer dignissim orci lorem, vel porttitor felis blandit et. Nam ultrices enim sed elementum accumsan. Fusce rutrum, quam et feugiat maximus, lorem leo porttitor ex, a eleifend risus odio consectetur lacus. In hac habitasse platea dictumst. Aenean pharetra erat tellus, at tempus urna iaculis ut. Ut ac mi eu lorem volutpat egestas.
Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Praesent porttitor tempor ligula. Quisque mollis arcu in metus ornare pellentesque. Aenean ultrices mollis quam quis sodales. Maecenas a cursus elit, id gravida tortor. Donec vel purus magna. Aliquam elementum est sed convallis fermentum. Nam nec eros arcu. Pellentesque sed eros a lacus sagittis maximus. Integer et tellus id libero dapibus convallis. Maecenas viverra, purus facilisis porttitor tincidunt, tellus lacus elementum dui, sed porttitor sem justo a lorem. Curabitur ipsum odio, efficitur quis efficitur at, tempus aliquet nisi. Aliquam ultrices tortor in arcu vulputate, vel iaculis lorem facilisis. Cras eleifend laoreet feugiat. Integer placerat blandit sem, mattis elementum purus pellentesque quis. Etiam vel arcu ut mi commodo placerat non id tortor.
"""
@pytest.mark.asyncio
async def test_supervisor_long_prompt_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
lorem_ipsum: str,
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_meta = MODEL_CARDS["llama-3.2-1b"].metadata
model_shard_meta = PipelineShardMetadata(
model_meta=model_meta,
device_rank=0,
world_size=1,
n_layers=model_meta.n_layers,
start_layer=0,
end_layer=model_meta.n_layers,
)
instance_id = InstanceId()
print(f"{model_shard_meta=}")
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
try:
full_response = ""
task = chat_completion_task(instance_id, TaskId())
task.task_params.messages[0].content = lorem_ipsum * 3
async for chunk in supervisor.stream_response(task=task):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
assert len(full_response) > 100
finally:
await supervisor.astop()
@pytest.mark.asyncio
async def test_supervisor_two_node_long_prompt_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
lorem_ipsum: str,
):
"""Test two-node long prompt inference"""
instance_id = InstanceId()
async def create_supervisor(shard_idx: int) -> RunnerSupervisor:
model_meta = MODEL_CARDS["llama-3.2-1b"].metadata
model_shard_meta = PipelineShardMetadata(
model_meta=model_meta,
device_rank=shard_idx,
world_size=2,
n_layers=model_meta.n_layers,
start_layer=0 if shard_idx == 0 else model_meta.n_layers // 2,
end_layer=model_meta.n_layers // 2
if shard_idx == 0
else model_meta.n_layers,
)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(2, offset=15),
)
return supervisor
create_supervisor_0 = asyncio.create_task(create_supervisor(0))
create_supervisor_1 = asyncio.create_task(create_supervisor(1))
supervisor_0, supervisor_1 = await asyncio.gather(
create_supervisor_0, create_supervisor_1
)
await asyncio.sleep(0.1)
try:
full_response_0 = ""
full_response_1 = ""
stop_reason_0: FinishReason | None = None
stop_reason_1: FinishReason | None = None
task = chat_completion_task(instance_id, TaskId())
task.task_params.messages[0].content = lorem_ipsum * 3
async def collect_response_0():
nonlocal full_response_0, stop_reason_0
async for chunk in supervisor_0.stream_response(task=task):
if isinstance(chunk, TokenChunk):
full_response_0 += chunk.text
if chunk.finish_reason:
stop_reason_0 = chunk.finish_reason
async def collect_response_1():
nonlocal full_response_1, stop_reason_1
async for chunk in supervisor_1.stream_response(task=task):
if isinstance(chunk, TokenChunk):
full_response_1 += chunk.text
if chunk.finish_reason:
stop_reason_1 = chunk.finish_reason
# Run both stream responses simultaneously
_ = await asyncio.gather(collect_response_0(), collect_response_1())
assert len(full_response_0) > 100
assert len(full_response_1) > 100
finally:
await supervisor_0.astop()
await supervisor_1.astop()

View File

@@ -1,58 +0,0 @@
from multiprocessing import Process
from typing import Callable
import psutil
import pytest
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.common import Host
from exo.shared.types.models import ModelMetadata
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.common import InstanceId, RunnerError
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
def get_memory_mb(process: Process) -> float:
"""
Returns the resident set size (RSS) memory usage in MiB for the given process.
"""
ps = psutil.Process(process.pid)
rss_bytes: int = ps.memory_info().rss # type: ignore[attr-defined]
return rss_bytes / (1024 * 1024)
@pytest.fixture
async def model_meta() -> ModelMetadata:
return await get_model_meta("mlx-community/Llama-3.3-70B-Instruct-4bit")
@pytest.mark.asyncio
async def test_supervisor_inference_exception(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
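"""A forced inference failure should surface as a RunnerError and release the model's memory."""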
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
process: Process = supervisor.runner_process
memory = get_memory_mb(process)
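# The loaded shard should occupy at least ~3 GiB of resident memory (30 * 100 MiB).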
assert memory > 30 * 100
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST FAIL"
with pytest.raises(RunnerError):
async for _ in supervisor.stream_response(task):
pass
await supervisor.astop()
available_memory_bytes: int = psutil.virtual_memory().available
print(available_memory_bytes // (2**30))
assert available_memory_bytes > 30 * 2**30

View File

@@ -1,48 +0,0 @@
from typing import Callable
import pytest
from exo.shared.types.common import Host
from exo.shared.types.tasks import (
Task,
TaskId,
)
from exo.shared.types.worker.common import InstanceId, RunnerError
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
@pytest.fixture
def user_message():
"""Override the default message to ask about France's capital"""
return "What is the capital of France?"
@pytest.mark.asyncio
@pytest.mark.skip(
reason="Must run `sudo sysctl -w iogpu.wired_limit_mb=` and `sudo sysctl -w iogpu.wired_lwm_mb=` before running this test."
)
async def test_supervisor_catches_oom(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST OOM"
with pytest.raises(RunnerError) as exc_info:
async for _ in supervisor.stream_response(task):
pass
error = exc_info.value
assert "memory" in error.error_message.lower()
await supervisor.astop()

View File

@@ -1,224 +0,0 @@
import asyncio
from typing import Callable
import pytest
from exo.shared.openai_compat import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import Host
from exo.shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
Task,
TaskId,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@pytest.fixture
def user_message():
"""Override the default message to ask about France's capital"""
return "What is the capital of France?"
@pytest.mark.asyncio
async def test_supervisor_single_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
instance_id = InstanceId()
print(f"{model_shard_meta=}")
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
try:
full_response = ""
stop_reason: FinishReason | None = None
async for chunk in supervisor.stream_response(
task=chat_completion_task(instance_id, TaskId())
):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
if chunk.finish_reason:
stop_reason = chunk.finish_reason
# Case-insensitive check for Paris in the response
assert "paris" in full_response.lower(), (
f"Expected 'Paris' in response, but got: {full_response}"
)
assert stop_reason == "stop"
finally:
await supervisor.astop()
@pytest.mark.asyncio
async def test_supervisor_two_node_response(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
instance_id = InstanceId()
async def create_supervisor(shard_idx: int) -> RunnerSupervisor:
supervisor = await RunnerSupervisor.create(
model_shard_meta=pipeline_shard_meta(2, shard_idx),
hosts=hosts(2, offset=15),
)
return supervisor
create_supervisor_0 = asyncio.create_task(create_supervisor(0))
create_supervisor_1 = asyncio.create_task(create_supervisor(1))
supervisor_0, supervisor_1 = await asyncio.gather(
create_supervisor_0, create_supervisor_1
)
await asyncio.sleep(0.1)
try:
full_response_0 = ""
full_response_1 = ""
async def collect_response_0():
nonlocal full_response_0
async for chunk in supervisor_0.stream_response(
task=chat_completion_task(instance_id, TaskId())
):
if isinstance(chunk, TokenChunk):
full_response_0 += chunk.text
async def collect_response_1():
nonlocal full_response_1
async for chunk in supervisor_1.stream_response(
task=chat_completion_task(instance_id, TaskId())
):
if isinstance(chunk, TokenChunk):
full_response_1 += chunk.text
# Run both stream responses simultaneously
_ = await asyncio.gather(collect_response_0(), collect_response_1())
print(f"full_response_0: {full_response_0}")
print(f"full_response_1: {full_response_1}")
# Case-insensitive check for Paris in both responses
assert "paris" in full_response_0.lower(), (
f"Expected 'Paris' in response, but got: {full_response_0}"
)
assert "paris" in full_response_1.lower(), (
f"Expected 'Paris' in response, but got: {full_response_1}"
)
finally:
await supervisor_0.astop()
await supervisor_1.astop()
@pytest.mark.asyncio
async def test_supervisor_early_stopping(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
instance_id = InstanceId()
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
task = chat_completion_task(instance_id, TaskId())
max_tokens = 50
assert isinstance(task, ChatCompletionTask)
print(f"chat_completion_task.task_params: {task.task_params}")
assert isinstance(task.task_params, ChatCompletionTaskParams)
task_params: ChatCompletionTaskParams = task.task_params
try:
task_params.max_tokens = max_tokens
# Convert messages to a list to allow indexing, then update the first message's content
messages = list(task_params.messages)
messages[0].content = "Please count from 1 to 100"
task_params.messages = messages
full_response = ""
count = 0
stop_reason: FinishReason | None = None
async for chunk in supervisor.stream_response(task=task):
if isinstance(chunk, TokenChunk):
full_response += chunk.text
count += 1
if chunk.finish_reason:
stop_reason = chunk.finish_reason
print(f"full_response: {full_response}")
assert count == max_tokens + 1
assert "7" in full_response.lower()
assert "99" not in full_response.lower()
assert stop_reason == "length"
finally:
await supervisor.astop()
@pytest.mark.asyncio
async def test_supervisor_handles_terminated_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
"""Test that the supervisor handles a terminated runner"""
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
# Terminate the runner
supervisor.runner_process.terminate()
await asyncio.sleep(0.1)
assert not supervisor.runner_process.is_alive()
del supervisor
@pytest.mark.asyncio
async def test_supervisor_handles_killed_runner(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
"""Test that the supervisor handles a killed runner"""
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
assert supervisor.runner_process.is_alive()
# Forcibly kill the runner
supervisor.runner_process.kill()
await asyncio.sleep(0.1)
assert not supervisor.runner_process.is_alive()
del supervisor

View File

@@ -1,92 +0,0 @@
import asyncio
from typing import Callable
import pytest
from exo.shared.types.common import Host
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.common import InstanceId, RunnerError
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.tests.constants import INSTANCE_1_ID, TASK_1_ID
@pytest.mark.asyncio
async def test_supervisor_instantiation_exception(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
model_shard_meta.immediate_exception = True
# _ = await RunnerSupervisor.create(
# model_shard_meta=model_shard_meta,
# hosts=hosts(1, offset=10),
# )
with pytest.raises(RunnerError):
_ = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
@pytest.mark.asyncio
async def test_supervisor_instantiation_timeout(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
model_shard_meta.should_timeout = 10 # timeout after 10s
with pytest.raises(asyncio.TimeoutError):
_ = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
@pytest.mark.asyncio
async def test_supervisor_inference_exception(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST FAIL"
with pytest.raises(RunnerError):
async for _ in supervisor.stream_response(task):
pass
@pytest.mark.asyncio
async def test_supervisor_inference_timeout(
pipeline_shard_meta: Callable[..., PipelineShardMetadata],
hosts: Callable[..., list[Host]],
chat_completion_task: Callable[[InstanceId, TaskId], Task],
):
"""Test that asking for the capital of France returns 'Paris' in the response"""
model_shard_meta = pipeline_shard_meta(1, 0)
supervisor = await RunnerSupervisor.create(
model_shard_meta=model_shard_meta,
hosts=hosts(1, offset=10),
)
task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT"
with pytest.raises(asyncio.TimeoutError):
async for _ in supervisor.stream_response(task):
pass
await asyncio.sleep(0.1)

View File

@@ -1,189 +0,0 @@
from dataclasses import dataclass
from typing import Callable
from anyio import fail_after
from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated
from exo.shared.types.tasks import TaskId, TaskStatus
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader
from exo.worker.main import Worker
from exo.worker.tests.constants import MASTER_NODE_ID
session = SessionId(master_node_id=MASTER_NODE_ID, election_clock=0)
@dataclass
class WorkerMailbox:
sender: Sender[ForwarderEvent]
receiver: Receiver[ForwarderEvent]
counter: int = 0
async def append_events(
self,
events: list[Event],
*,
origin: NodeId,
):
for event in events:
await self.sender.send(
ForwarderEvent(
origin=origin,
session=session,
event=event,
origin_idx=self.counter,
)
)
self.counter += 1
async def receive(self) -> ForwarderEvent:
return await self.receiver.receive()
def collect(self) -> list[ForwarderEvent]:
# Clear out the test mailboxes currently held events
return self.receiver.collect()
def create_worker_void_mailbox(
node_id: NodeId, shard_downloader: ShardDownloader | None = None
) -> Worker:
if shard_downloader is None:
shard_downloader = NoopShardDownloader()
return Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],
global_event_receiver=channel[ForwarderEvent]()[1],
local_event_sender=channel[ForwarderEvent]()[0],
command_sender=channel[ForwarderCommand]()[0],
)
def create_worker_and_mailbox(
node_id: NodeId, shard_downloader: ShardDownloader | None = None
) -> tuple[Worker, WorkerMailbox]:
if shard_downloader is None:
shard_downloader = NoopShardDownloader()
lsend, receiver = channel[ForwarderEvent]()
sender, grecv = channel[ForwarderEvent]()
worker = Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],
global_event_receiver=grecv,
local_event_sender=lsend,
command_sender=channel[ForwarderCommand]()[0],
)
return worker, WorkerMailbox(sender, receiver)
def create_worker_with_old_mailbox(
node_id: NodeId,
mailbox: WorkerMailbox,
shard_downloader: ShardDownloader | None = None,
) -> Worker:
if shard_downloader is None:
shard_downloader = NoopShardDownloader()
# This function is subtly complex, come talk to Evan if you want to know what it's actually doing.
worker = Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],
global_event_receiver=mailbox.sender.clone_receiver(),
local_event_sender=mailbox.receiver.clone_sender(),
command_sender=channel[ForwarderCommand]()[0],
)
return worker
async def read_streaming_response(
global_event_receiver: WorkerMailbox, filter_task: TaskId | None = None
) -> tuple[bool, bool, str, int]:
# Read off all events - these should be our GenerationChunk events
seen_task_started = 0
seen_task_finished = 0
response_string = ""
finish_reason: str | None = None
token_count = 0
extra_events: list[Event] = []
event = (await global_event_receiver.receive()).event
extra_events.append(event)
from loguru import logger
logger.info("STARTING READ")
with fail_after(10.0):
if filter_task:
while not (
isinstance(event, TaskStateUpdated)
and event.task_status == TaskStatus.Running
and event.task_id == filter_task
):
event = (await global_event_receiver.receive()).event
extra_events.append(event)
for event in extra_events:
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.Running:
seen_task_started += 1
if event.task_status == TaskStatus.Complete:
seen_task_finished += 1
if isinstance(event, ChunkGenerated) and isinstance(
event.chunk, TokenChunk
):
response_string += event.chunk.text
token_count += 1
if event.chunk.finish_reason:
finish_reason = event.chunk.finish_reason
while not seen_task_finished:
event = (await global_event_receiver.receive()).event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.Running:
seen_task_started += 1
if event.task_status == TaskStatus.Complete:
seen_task_finished += 1
if isinstance(event, ChunkGenerated) and isinstance(
event.chunk, TokenChunk
):
response_string += event.chunk.text
token_count += 1
if event.chunk.finish_reason:
finish_reason = event.chunk.finish_reason
logger.info(f"finish reason {finish_reason}")
return seen_task_started == 1, seen_task_finished == 1, response_string, token_count
async def until_event_with_timeout[T](
global_event_receiver: WorkerMailbox,
event_type: type[T],
multiplicity: int = 1,
condition: Callable[[T], bool] = lambda x: True,
timeout: float = 30.0,
) -> None:
times_seen = 0
with fail_after(timeout):
while times_seen < multiplicity:
event = (await global_event_receiver.receive()).event
if isinstance(event, event_type):
print(f"Wow! We got a {event}")
print(
f"But condition? {condition(event) if isinstance(event, event_type) else False}"
)
if event and isinstance(event, event_type) and condition(event):
times_seen += 1

View File

@@ -0,0 +1,97 @@
import platform
import shutil
from subprocess import CalledProcessError
from anyio import run_process
from pydantic import BaseModel, ConfigDict, ValidationError
class MacMonError(Exception):
"""Exception raised for errors in the MacMon functions."""
def _get_binary_path() -> str:
"""
Get the path to the macmon binary.
Raises:
MacMonError: If not running on Apple Silicon macOS, or if macmon is not found in PATH.
"""
# Check for macOS with ARM chip
system = platform.system().lower()
machine = platform.machine().lower()
if system != "darwin" or not (
"arm" in machine or "m1" in machine or "m2" in machine
):
raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")
path = shutil.which("macmon")
if path is None:
raise MacMonError("MacMon not found in PATH")
return path
class TempMetrics(BaseModel):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
model_config = ConfigDict(extra="ignore")
class Metrics(BaseModel):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
all_power: float
ane_power: float
cpu_power: float
ecpu_usage: tuple[int, float]
gpu_power: float
gpu_ram_power: float
gpu_usage: tuple[int, float]
pcpu_usage: tuple[int, float]
ram_power: float
sys_power: float
temp: TempMetrics
timestamp: str
model_config = ConfigDict(extra="ignore")
async def get_metrics_async() -> Metrics:
"""
Asynchronously run the macmon binary and return the parsed metrics.
Returns:
A Metrics model containing the latest system metrics.
Raises:
MacMonError: If there's an error running the binary.
"""
path = _get_binary_path()
try:
# TODO: Keep Macmon running in the background?
result = await run_process([path, "pipe", "-s", "1"])
return Metrics.model_validate_json(result.stdout.decode().strip())
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
except CalledProcessError as e:
# run_process raises before `result` is assigned, so take the return code from the exception.
raise MacMonError(f"MacMon failed with return code {e.returncode}") from e
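
For reference, a minimal usage sketch of the new module (illustrative only, not part of the commit). It assumes the module is importable as exo.worker.utils.macmon, as the imports elsewhere in this commit suggest, and that it runs on Apple Silicon macOS with macmon on PATH:

import anyio

from exo.worker.utils.macmon import MacMonError, get_metrics_async


async def main() -> None:
    try:
        metrics = await get_metrics_async()
    except MacMonError as e:
        print(f"macmon unavailable: {e}")
        return
    # gpu_usage is a (frequency, utilization) tuple; index 1 is treated as the
    # utilization fraction, mirroring how start_polling_node_metrics consumes it.
    print(f"GPU usage: {metrics.gpu_usage[1]:.2%}")
    # Temperatures are average die temps; sys_power is assumed to be in watts.
    print(f"GPU temp: {metrics.temp.gpu_temp_avg:.1f} C, system power: {metrics.sys_power:.2f} W")


if __name__ == "__main__":
    anyio.run(main)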

View File

Binary file not shown.

View File

@@ -1,3 +0,0 @@
from .macmon import MacMonError, get_metrics, get_metrics_async
__all__ = ["get_metrics", "get_metrics_async", "MacMonError"]

View File

@@ -1,150 +0,0 @@
import asyncio
import platform
import shutil
import subprocess
from pydantic import BaseModel, ConfigDict, ValidationError
class MacMonError(Exception):
"""Exception raised for errors in the MacMon functions."""
def _get_binary_path() -> str:
"""
Get the path to the macmon binary.
Raises:
MacMonError: If the binary doesn't exist or can't be made executable.
"""
# Check for macOS with ARM chip
system = platform.system().lower()
machine = platform.machine().lower()
if system != "darwin" or not (
"arm" in machine or "m1" in machine or "m2" in machine
):
raise MacMonError("MacMon only supports macOS with Apple Silicon (ARM) chips")
path = shutil.which("macmon")
if path is None:
raise MacMonError("MacMon not found in PATH")
return path
# ---------------------------------------------------------------------------
# Pydantic metric structures
# ---------------------------------------------------------------------------
class MemoryMetrics(BaseModel):
"""Memory-related metrics returned by macmon."""
ram_total: int | None = None
ram_usage: int | None = None
swap_total: int | None = None
swap_usage: int | None = None
model_config = ConfigDict(extra="ignore")
class TempMetrics(BaseModel):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float | None = None
gpu_temp_avg: float | None = None
model_config = ConfigDict(extra="ignore")
class Metrics(BaseModel):
"""Complete set of metrics returned by *macmon* binary.
All fields are optional to allow for partial output from the binary.
Unknown fields are ignored for forward-compatibility.
"""
all_power: float | None = None
ane_power: float | None = None
cpu_power: float | None = None
ecpu_usage: tuple[int, float] | None = None
gpu_power: float | None = None
gpu_ram_power: float | None = None
gpu_usage: tuple[int, float] | None = None
memory: MemoryMetrics | None = None
pcpu_usage: tuple[int, float] | None = None
ram_power: float | None = None
sys_power: float | None = None
temp: TempMetrics | None = None
timestamp: str | None = None
model_config = ConfigDict(extra="ignore")
# ---------------------------------------------------------------------------
# Synchronous helper
# ---------------------------------------------------------------------------
def get_metrics() -> Metrics:
"""
Run the binary and return the metrics as a Python dictionary.
Returns:
A mapping containing system metrics.
Raises:
MacMonError: If there's an error running the binary.
"""
path = _get_binary_path()
try:
# Run the binary with the argument -s 1 and capture its output
result = subprocess.run(
[path, "pipe", "-s", "1"], capture_output=True, text=True, check=True
)
return Metrics.model_validate_json(result.stdout)
except subprocess.CalledProcessError as e:
raise MacMonError(f"Error running binary: {e.stderr}") from e # type: ignore
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e
async def get_metrics_async() -> Metrics:
"""
Asynchronously run the binary and return the metrics as a Python dictionary.
Args:
binary_path: Optional path to the binary. If not provided, will use the bundled binary.
Returns:
A mapping containing system metrics.
Raises:
MacMonError: If there's an error running the binary.
"""
path = _get_binary_path()
try:
proc = await asyncio.create_subprocess_exec(
path,
"pipe",
"-s",
"1",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise MacMonError(f"Error running binary: {stderr.decode().strip()}")
return Metrics.model_validate_json(stdout.decode().strip())
except ValidationError as e:
raise MacMonError(f"Error parsing JSON output: {e}") from e

View File

@@ -4,7 +4,6 @@ import platform
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
import anyio import anyio
import psutil
from loguru import logger from loguru import logger
from exo.shared.types.memory import Memory from exo.shared.types.memory import Memory
@@ -13,59 +12,37 @@ from exo.shared.types.profiling import (
NodePerformanceProfile, NodePerformanceProfile,
SystemPerformanceProfile, SystemPerformanceProfile,
) )
from exo.worker.utils.macmon.macmon import ( from exo.worker.utils.macmon import (
MacMonError,
Metrics, Metrics,
) )
from exo.worker.utils.macmon.macmon import ( from exo.worker.utils.macmon import (
get_metrics_async as macmon_get_metrics_async, get_metrics_async as macmon_get_metrics_async,
) )
from exo.worker.utils.system_info import ( from exo.worker.utils.system_info import (
get_mac_friendly_name_async, get_friendly_name,
get_mac_system_info_async, get_model_and_chip,
get_network_interface_info_async, get_network_interfaces,
) )
async def get_metrics_async() -> Metrics: async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere. """Return detailed Metrics on macOS or a minimal fallback elsewhere."""
The *Metrics* schema comes from ``utils.macmon.macmon``; on non-macOS systems we
fill only the ``memory`` sub-structure so downstream code can still access
``metrics.memory.ram_total`` & ``ram_usage``.
"""
if platform.system().lower() == "darwin": if platform.system().lower() == "darwin":
return await macmon_get_metrics_async() return await macmon_get_metrics_async()
return Metrics()
async def get_memory_profile_async() -> MemoryPerformanceProfile: def get_memory_profile() -> MemoryPerformanceProfile:
"""Return MemoryPerformanceProfile using psutil (fast, cross-platform). """Construct a MemoryPerformanceProfile using psutil"""
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
Uses synchronous psutil calls in a worker thread to avoid blocking the event loop. return MemoryPerformanceProfile.from_psutil(override_memory=override_memory)
"""
def _read_psutil() -> MemoryPerformanceProfile:
vm = psutil.virtual_memory()
sm = psutil.swap_memory()
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
return MemoryPerformanceProfile.from_bytes(
ram_total=int(vm.total),
ram_available=int(override_memory)
if override_memory
else int(vm.available),
swap_total=int(sm.total),
swap_available=int(sm.free),
)
return await asyncio.to_thread(_read_psutil)
async def start_polling_memory_metrics( async def start_polling_memory_metrics(
@@ -81,9 +58,9 @@ async def start_polling_memory_metrics(
""" """
while True: while True:
try: try:
mem = await get_memory_profile_async() mem = get_memory_profile()
await callback(mem) await callback(mem)
except Exception as e: except MacMonError as e:
logger.opt(exception=e).error("Memory Monitor encountered error") logger.opt(exception=e).error("Memory Monitor encountered error")
finally: finally:
await anyio.sleep(poll_interval_s) await anyio.sleep(poll_interval_s)
@@ -95,61 +72,41 @@ async def start_polling_node_metrics(
poll_interval_s = 1.0 poll_interval_s = 1.0
while True: while True:
try: try:
# Gather metrics & system info with a timeout on each call
metrics = await get_metrics_async() metrics = await get_metrics_async()
if metrics is None:
return
( network_interfaces = get_network_interfaces()
system_info, # these awaits could be joined but realistically they should be cached
network_interfaces, model_id, chip_id = await get_model_and_chip()
mac_friendly_name, friendly_name = await get_friendly_name()
) = await asyncio.gather(
get_mac_system_info_async(),
get_network_interface_info_async(),
get_mac_friendly_name_async(),
)
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop # do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = await get_memory_profile_async() memory_profile = get_memory_profile()
await callback( await callback(
NodePerformanceProfile( NodePerformanceProfile(
model_id=system_info.model_id, model_id=model_id,
chip_id=system_info.chip_id, chip_id=chip_id,
friendly_name=mac_friendly_name or "Unknown", friendly_name=friendly_name,
network_interfaces=network_interfaces, network_interfaces=network_interfaces,
memory=memory_profile, memory=memory_profile,
system=SystemPerformanceProfile( system=SystemPerformanceProfile(
flops_fp16=0, gpu_usage=metrics.gpu_usage[1],
gpu_usage=metrics.gpu_usage[1] temp=metrics.temp.gpu_temp_avg,
if metrics.gpu_usage is not None sys_power=metrics.sys_power,
else 0, pcpu_usage=metrics.pcpu_usage[1],
temp=metrics.temp.gpu_temp_avg ecpu_usage=metrics.ecpu_usage[1],
if metrics.temp is not None ane_power=metrics.ane_power,
and metrics.temp.gpu_temp_avg is not None
else 0,
sys_power=metrics.sys_power
if metrics.sys_power is not None
else 0,
pcpu_usage=metrics.pcpu_usage[1]
if metrics.pcpu_usage is not None
else 0,
ecpu_usage=metrics.ecpu_usage[1]
if metrics.ecpu_usage is not None
else 0,
ane_power=metrics.ane_power
if metrics.ane_power is not None
else 0,
), ),
) )
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
# One of the operations took too long; skip this iteration but keep the loop alive.
logger.warning( logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle." "[resource_monitor] Operation timed out after 30s, skipping this cycle."
) )
except Exception as e: except MacMonError as e:
# Catch-all to ensure the monitor keeps running.
logger.opt(exception=e).error("Resource Monitor encountered error") logger.opt(exception=e).error("Resource Monitor encountered error")
finally: finally:
await anyio.sleep(poll_interval_s) await anyio.sleep(poll_interval_s)

View File

@@ -1,77 +1,34 @@
import asyncio import socket
import re
import sys import sys
from subprocess import CalledProcessError
from loguru import logger import psutil
from pydantic import BaseModel, Field from anyio import run_process
from exo.shared.types.profiling import NetworkInterfaceInfo from exo.shared.types.profiling import NetworkInterfaceInfo
class SystemInfo(BaseModel): async def get_friendly_name() -> str:
model_id: str
chip_id: str
memory: int
network_interfaces: list[NetworkInterfaceInfo] = Field(default_factory=list)
async def get_mac_friendly_name_async() -> str | None:
""" """
Asynchronously gets the 'Computer Name' (friendly name) of a Mac. Asynchronously gets the 'Computer Name' (friendly name) of a Mac.
e.g., "John's MacBook Pro" e.g., "John's MacBook Pro"
Returns the name as a string, or None if an error occurs or not on macOS. Returns the name as a string, or None if an error occurs or not on macOS.
""" """
hostname = socket.gethostname()
# TODO: better non mac support
if sys.platform != "darwin": # 'darwin' is the platform name for macOS if sys.platform != "darwin": # 'darwin' is the platform name for macOS
logger.warning("Mac friendly name is designed for macOS only.") return hostname
return None
try: try:
# asyncio.create_subprocess_exec allows running external commands asynchronously. process = await run_process(["scutil", "--get", "ComputerName"])
# stdout=asyncio.subprocess.PIPE captures standard output. except CalledProcessError:
# stderr=asyncio.subprocess.PIPE captures standard error. return hostname
process = await asyncio.create_subprocess_exec(
"scutil",
"--get",
"ComputerName",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
# process.communicate() reads all data from stdout and stderr return process.stdout.decode("utf-8", errors="replace").strip() or hostname
# and waits for the process to terminate.
# It returns a tuple (stdout_data, stderr_data).
stdout_data, stderr_data = await process.communicate()
# Check the return code of the process
if process.returncode == 0:
if stdout_data:
# Decode from bytes to string and strip whitespace
friendly_name = stdout_data.decode().strip()
return friendly_name
else:
# Should not happen if returncode is 0, but good to check
print("scutil command succeeded but produced no output.")
return None
else:
# If there was an error, print the stderr output
error_message = (
stderr_data.decode().strip() if stderr_data else "Unknown error"
)
print(
f"Error executing scutil (return code {process.returncode}): {error_message}"
)
return None
except FileNotFoundError:
# This would happen if scutil is somehow not found, highly unlikely on a Mac.
print("Error: 'scutil' command not found. Are you sure this is macOS?")
return None
except Exception as e:
print(f"An unexpected error occurred: {e}")
return None
async def get_network_interface_info_async() -> list[NetworkInterfaceInfo]: def get_network_interfaces() -> list[NetworkInterfaceInfo]:
""" """
Retrieves detailed network interface information on macOS. Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig' Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -80,162 +37,47 @@ async def get_network_interface_info_async() -> list[NetworkInterfaceInfo]:
""" """
interfaces_info: list[NetworkInterfaceInfo] = [] interfaces_info: list[NetworkInterfaceInfo] = []
async def _run_cmd_async(command_parts: list[str]) -> str | None: for iface, services in psutil.net_if_addrs().items():
# Helper to run a command and return its stdout, or None on error. for service in services:
try: match service.family:
process = await asyncio.create_subprocess_exec( case socket.AF_INET | socket.AF_INET6:
*command_parts,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout_data, stderr_data = await process.communicate()
if process.returncode == 0:
# Use 'utf-8' and replace errors for robustness
return stdout_data.decode("utf-8", errors="replace").strip()
else:
error_message = (
stderr_data.decode("utf-8", errors="replace").strip()
if stderr_data
else "Unknown error"
)
print(
f"Error executing {' '.join(command_parts)} (code {process.returncode}): {error_message}"
)
return None
except FileNotFoundError:
print(
f"Error: Command '{command_parts[0]}' not found. Ensure it's in PATH."
)
return None
except Exception as e:
print(
f"An unexpected error occurred running {' '.join(command_parts)}: {e}"
)
return None
# Get interface names and IP addresses from ifconfig
ifconfig_output = await _run_cmd_async(["ifconfig"])
if ifconfig_output:
# Regex for interface name (e.g., en0:, utun0:, tailscale0.)
interface_header_pattern = re.compile(r"^([a-zA-Z0-9\._-]+):")
# Regex for IPv4 address (inet)
inet_pattern = re.compile(r"^\s+inet\s+(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})")
# Regex for IPv6 address (inet6)
inet6_pattern = re.compile(r"^\s+inet6\s+([0-9a-fA-F:]+(?:%[a-zA-Z0-9._-]+)?)")
current_if_name: str | None = None
for line in ifconfig_output.splitlines():
header_match = interface_header_pattern.match(line)
if header_match:
current_if_name = header_match.group(1)
if current_if_name:
inet_m = inet_pattern.match(line)
if inet_m:
ipv4_address = inet_m.group(1)
interfaces_info.append( interfaces_info.append(
NetworkInterfaceInfo( NetworkInterfaceInfo(name=iface, ip_address=service.address)
name=current_if_name, ip_address=ipv4_address, type=""
)
)
inet6_m = inet6_pattern.match(line)
if inet6_m:
ipv6_address = inet6_m.group(1)
interfaces_info.append(
NetworkInterfaceInfo(
name=current_if_name, ip_address=ipv6_address, type=""
)
) )
case _:
pass
return interfaces_info return interfaces_info
async def get_mac_system_info_async() -> SystemInfo: async def get_model_and_chip() -> tuple[str, str]:
"""Get Mac system information using system_profiler.""" """Get Mac system information using system_profiler."""
model_id_val = "Unknown Model" model = "Unknown Model"
chip_id_val = "Unknown Chip" chip = "Unknown Chip"
memory_val = 0
network_interfaces_info_list: list[NetworkInterfaceInfo] = []
# TODO: better non mac support
if sys.platform != "darwin": if sys.platform != "darwin":
return SystemInfo( return (model, chip)
model_id=model_id_val,
chip_id=chip_id_val,
memory=memory_val,
network_interfaces=network_interfaces_info_list,
)
try: try:
process = await asyncio.create_subprocess_exec( process = await run_process(
"system_profiler", [
"SPHardwareDataType", "system_profiler",
stdout=asyncio.subprocess.PIPE, "SPHardwareDataType",
stderr=asyncio.subprocess.PIPE, ]
) )
stdout_data, stderr_data = await process.communicate() except CalledProcessError:
if process.returncode == 0: return (model, chip)
if stdout_data:
output = stdout_data.decode().strip()
model_line = next(
(line for line in output.split("\n") if "Model Name" in line), None
)
model_id_val = (
model_line.split(": ")[1] if model_line else "Unknown Model"
)
chip_line = next( # less interested in errors here because this value should be hard coded
(line for line in output.split("\n") if "Chip" in line), None output = process.stdout.decode().strip()
)
chip_id_val = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
memory_line = next( model_line = next(
(line for line in output.split("\n") if "Memory" in line), None (line for line in output.split("\n") if "Model Name" in line), None
)
memory_str = (
memory_line.split(": ")[1] if memory_line else "0 GB"
) # Default to "0 GB"
memory_units = memory_str.split()
if len(memory_units) == 2:
try:
memory_value_int = int(memory_units[0])
if memory_units[1] == "GB":
memory_val = memory_value_int * 1024 # Assuming MB
elif memory_units[1] == "MB":
memory_val = memory_value_int
else: # TB? Unlikely for typical memory, handle gracefully
memory_val = memory_value_int # Store as is, let consumer decide unit or log
print(f"Warning: Unknown memory unit {memory_units[1]}")
except ValueError:
print(
f"Warning: Could not parse memory value {memory_units[0]}"
)
memory_val = 0
else:
print(
"system_profiler command succeeded but produced no output for hardware."
)
else:
error_message = (
stderr_data.decode().strip() if stderr_data else "Unknown error"
)
print(
f"Error executing system_profiler (return code {process.returncode}): {error_message}"
)
except Exception as e:
print(f"Error getting Mac hardware info: {e}")
# Call the new function to get network info
try:
network_interfaces_info_list = await get_network_interface_info_async()
except Exception as e:
print(f"Error getting Mac network interface info: {e}")
network_interfaces_info_list = []
return SystemInfo(
model_id=model_id_val,
chip_id=chip_id_val,
memory=memory_val,
network_interfaces=network_interfaces_info_list,
) )
model = model_line.split(": ")[1] if model_line else "Unknown Model"
chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
return (model, chip)
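
For clarity, a standalone sketch of the psutil-based enumeration that the new get_network_interfaces performs (illustrative only; NetworkInterfaceInfo is simplified to a plain (name, ip) tuple here):

import socket

import psutil


def list_ip_addresses() -> list[tuple[str, str]]:
    addresses: list[tuple[str, str]] = []
    for iface, addrs in psutil.net_if_addrs().items():
        for addr in addrs:
            # Keep IPv4/IPv6 entries; skip link-layer (MAC) addresses.
            if addr.family in (socket.AF_INET, socket.AF_INET6):
                addresses.append((iface, addr.address))
    return addresses


if __name__ == "__main__":
    for name, ip in list_ip_addresses():
        print(f"{name}: {ip}")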

uv.lock generated
View File

@@ -1,5 +1,5 @@
version = 1 version = 1
revision = 1 revision = 3
requires-python = ">=3.13" requires-python = ">=3.13"
resolution-markers = [ resolution-markers = [
"sys_platform == 'darwin'", "sys_platform == 'darwin'",
@@ -320,15 +320,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" }, { url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" },
] ]
[[package]]
name = "distro"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = "2023-12-24T09:54:32.31Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]
[[package]] [[package]]
name = "exo" name = "exo"
version = "0.3.0" version = "0.3.0"
@@ -351,7 +342,6 @@ dependencies = [
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pathlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pathlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -372,6 +362,7 @@ dependencies = [
dev = [ dev = [
{ name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pytest-env", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "ruff", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "ruff", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
] ]
@@ -394,7 +385,6 @@ requires-dist = [
{ name = "mlx", specifier = ">=0.29.3" }, { name = "mlx", specifier = ">=0.29.3" },
{ name = "mlx-lm", specifier = ">=0.28.3" }, { name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" }, { name = "networkx", specifier = ">=3.5" },
{ name = "openai", specifier = ">=1.99.9" },
{ name = "pathlib", specifier = ">=1.0.1" }, { name = "pathlib", specifier = ">=1.0.1" },
{ name = "protobuf", specifier = ">=6.32.0" }, { name = "protobuf", specifier = ">=6.32.0" },
{ name = "psutil", specifier = ">=7.0.0" }, { name = "psutil", specifier = ">=7.0.0" },
@@ -415,6 +405,7 @@ requires-dist = [
dev = [ dev = [
{ name = "pytest", specifier = ">=8.4.0" }, { name = "pytest", specifier = ">=8.4.0" },
{ name = "pytest-asyncio", specifier = ">=1.0.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" },
{ name = "pytest-env" },
{ name = "ruff", specifier = ">=0.11.13" }, { name = "ruff", specifier = ">=0.11.13" },
] ]
@@ -594,34 +585,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" },
] ]
[[package]]
name = "httpcore"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "h11", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
]
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.36.0" version = "0.36.0"
@@ -671,46 +634,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
] ]
[[package]]
name = "jiter"
version = "0.11.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a3/68/0357982493a7b20925aece061f7fb7a2678e3b232f8d73a6edb7e5304443/jiter-0.11.1.tar.gz", hash = "sha256:849dcfc76481c0ea0099391235b7ca97d7279e0fa4c86005457ac7c88e8b76dc", size = 168385, upload-time = "2025-10-17T11:31:15.186Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/4b/e4dd3c76424fad02a601d570f4f2a8438daea47ba081201a721a903d3f4c/jiter-0.11.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:71b6a920a5550f057d49d0e8bcc60945a8da998019e83f01adf110e226267663", size = 305272, upload-time = "2025-10-17T11:29:39.249Z" },
{ url = "https://files.pythonhosted.org/packages/67/83/2cd3ad5364191130f4de80eacc907f693723beaab11a46c7d155b07a092c/jiter-0.11.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b3de72e925388453a5171be83379549300db01284f04d2a6f244d1d8de36f94", size = 314038, upload-time = "2025-10-17T11:29:40.563Z" },
{ url = "https://files.pythonhosted.org/packages/d3/3c/8e67d9ba524e97d2f04c8f406f8769a23205026b13b0938d16646d6e2d3e/jiter-0.11.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc19dd65a2bd3d9c044c5b4ebf657ca1e6003a97c0fc10f555aa4f7fb9821c00", size = 345977, upload-time = "2025-10-17T11:29:42.009Z" },
{ url = "https://files.pythonhosted.org/packages/8d/a5/489ce64d992c29bccbffabb13961bbb0435e890d7f2d266d1f3df5e917d2/jiter-0.11.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d58faaa936743cd1464540562f60b7ce4fd927e695e8bc31b3da5b914baa9abd", size = 364503, upload-time = "2025-10-17T11:29:43.459Z" },
{ url = "https://files.pythonhosted.org/packages/d4/c0/e321dd83ee231d05c8fe4b1a12caf1f0e8c7a949bf4724d58397104f10f2/jiter-0.11.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:902640c3103625317291cb73773413b4d71847cdf9383ba65528745ff89f1d14", size = 487092, upload-time = "2025-10-17T11:29:44.835Z" },
{ url = "https://files.pythonhosted.org/packages/f9/5e/8f24ec49c8d37bd37f34ec0112e0b1a3b4b5a7b456c8efff1df5e189ad43/jiter-0.11.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:30405f726e4c2ed487b176c09f8b877a957f535d60c1bf194abb8dadedb5836f", size = 376328, upload-time = "2025-10-17T11:29:46.175Z" },
{ url = "https://files.pythonhosted.org/packages/7f/70/ded107620e809327cf7050727e17ccfa79d6385a771b7fe38fb31318ef00/jiter-0.11.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3217f61728b0baadd2551844870f65219ac4a1285d5e1a4abddff3d51fdabe96", size = 356632, upload-time = "2025-10-17T11:29:47.454Z" },
{ url = "https://files.pythonhosted.org/packages/19/53/c26f7251613f6a9079275ee43c89b8a973a95ff27532c421abc2a87afb04/jiter-0.11.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b1364cc90c03a8196f35f396f84029f12abe925415049204446db86598c8b72c", size = 384358, upload-time = "2025-10-17T11:29:49.377Z" },
{ url = "https://files.pythonhosted.org/packages/84/16/e0f2cc61e9c4d0b62f6c1bd9b9781d878a427656f88293e2a5335fa8ff07/jiter-0.11.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53a54bf8e873820ab186b2dca9f6c3303f00d65ae5e7b7d6bda1b95aa472d646", size = 517279, upload-time = "2025-10-17T11:29:50.968Z" },
{ url = "https://files.pythonhosted.org/packages/60/5c/4cd095eaee68961bca3081acbe7c89e12ae24a5dae5fd5d2a13e01ed2542/jiter-0.11.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7e29aca023627b0e0c2392d4248f6414d566ff3974fa08ff2ac8dbb96dfee92a", size = 508276, upload-time = "2025-10-17T11:29:52.619Z" },
{ url = "https://files.pythonhosted.org/packages/65/9b/4a57922437ca8753ef823f434c2dec5028b237d84fa320f06a3ba1aec6e8/jiter-0.11.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d892b184da4d94d94ddb4031296931c74ec8b325513a541ebfd6dfb9ae89904b", size = 313814, upload-time = "2025-10-17T11:29:58.509Z" },
{ url = "https://files.pythonhosted.org/packages/76/50/62a0683dadca25490a4bedc6a88d59de9af2a3406dd5a576009a73a1d392/jiter-0.11.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa22c223a3041dacb2fcd37c70dfd648b44662b4a48e242592f95bda5ab09d58", size = 344987, upload-time = "2025-10-17T11:30:00.208Z" },
{ url = "https://files.pythonhosted.org/packages/da/00/2355dbfcbf6cdeaddfdca18287f0f38ae49446bb6378e4a5971e9356fc8a/jiter-0.11.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330e8e6a11ad4980cd66a0f4a3e0e2e0f646c911ce047014f984841924729789", size = 356399, upload-time = "2025-10-17T11:30:02.084Z" },
{ url = "https://files.pythonhosted.org/packages/8d/00/d6006d069e7b076e4c66af90656b63da9481954f290d5eca8c715f4bf125/jiter-0.11.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:0fa1f70da7a8a9713ff8e5f75ec3f90c0c870be6d526aa95e7c906f6a1c8c676", size = 304624, upload-time = "2025-10-17T11:30:06.678Z" },
{ url = "https://files.pythonhosted.org/packages/fc/45/4a0e31eb996b9ccfddbae4d3017b46f358a599ccf2e19fbffa5e531bd304/jiter-0.11.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:569ee559e5046a42feb6828c55307cf20fe43308e3ae0d8e9e4f8d8634d99944", size = 315042, upload-time = "2025-10-17T11:30:08.87Z" },
{ url = "https://files.pythonhosted.org/packages/e7/91/22f5746f5159a28c76acdc0778801f3c1181799aab196dbea2d29e064968/jiter-0.11.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f69955fa1d92e81987f092b233f0be49d4c937da107b7f7dcf56306f1d3fcce9", size = 346357, upload-time = "2025-10-17T11:30:10.222Z" },
{ url = "https://files.pythonhosted.org/packages/f5/4f/57620857d4e1dc75c8ff4856c90cb6c135e61bff9b4ebfb5dc86814e82d7/jiter-0.11.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:090f4c9d4a825e0fcbd0a2647c9a88a0f366b75654d982d95a9590745ff0c48d", size = 365057, upload-time = "2025-10-17T11:30:11.585Z" },
{ url = "https://files.pythonhosted.org/packages/ce/34/caf7f9cc8ae0a5bb25a5440cc76c7452d264d1b36701b90fdadd28fe08ec/jiter-0.11.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbf3d8cedf9e9d825233e0dcac28ff15c47b7c5512fdfe2e25fd5bbb6e6b0cee", size = 487086, upload-time = "2025-10-17T11:30:13.052Z" },
{ url = "https://files.pythonhosted.org/packages/50/17/85b5857c329d533d433fedf98804ebec696004a1f88cabad202b2ddc55cf/jiter-0.11.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2aa9b1958f9c30d3d1a558b75f0626733c60eb9b7774a86b34d88060be1e67fe", size = 376083, upload-time = "2025-10-17T11:30:14.416Z" },
{ url = "https://files.pythonhosted.org/packages/85/d3/2d9f973f828226e6faebdef034097a2918077ea776fb4d88489949024787/jiter-0.11.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e42d1ca16590b768c5e7d723055acd2633908baacb3628dd430842e2e035aa90", size = 357825, upload-time = "2025-10-17T11:30:15.765Z" },
{ url = "https://files.pythonhosted.org/packages/f4/55/848d4dabf2c2c236a05468c315c2cb9dc736c5915e65449ccecdba22fb6f/jiter-0.11.1-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5db4c2486a023820b701a17aec9c5a6173c5ba4393f26662f032f2de9c848b0f", size = 383933, upload-time = "2025-10-17T11:30:17.34Z" },
{ url = "https://files.pythonhosted.org/packages/0b/6c/204c95a4fbb0e26dfa7776c8ef4a878d0c0b215868011cc904bf44f707e2/jiter-0.11.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:4573b78777ccfac954859a6eff45cbd9d281d80c8af049d0f1a3d9fc323d5c3a", size = 517118, upload-time = "2025-10-17T11:30:18.684Z" },
{ url = "https://files.pythonhosted.org/packages/88/25/09956644ea5a2b1e7a2a0f665cb69a973b28f4621fa61fc0c0f06ff40a31/jiter-0.11.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:7593ac6f40831d7961cb67633c39b9fef6689a211d7919e958f45710504f52d3", size = 508194, upload-time = "2025-10-17T11:30:20.719Z" },
{ url = "https://files.pythonhosted.org/packages/d5/fa/3b05e5c9d32efc770a8510eeb0b071c42ae93a5b576fd91cee9af91689a1/jiter-0.11.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2cc5a3965285ddc33e0cab933e96b640bc9ba5940cea27ebbbf6695e72d6511c", size = 312561, upload-time = "2025-10-17T11:30:26.742Z" },
{ url = "https://files.pythonhosted.org/packages/50/d3/335822eb216154ddb79a130cbdce88fdf5c3e2b43dc5dba1fd95c485aaf5/jiter-0.11.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b572b3636a784c2768b2342f36a23078c8d3aa6d8a30745398b1bab58a6f1a8", size = 344551, upload-time = "2025-10-17T11:30:28.252Z" },
{ url = "https://files.pythonhosted.org/packages/31/6d/a0bed13676b1398f9b3ba61f32569f20a3ff270291161100956a577b2dd3/jiter-0.11.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ad93e3d67a981f96596d65d2298fe8d1aa649deb5374a2fb6a434410ee11915e", size = 363051, upload-time = "2025-10-17T11:30:30.009Z" },
{ url = "https://files.pythonhosted.org/packages/a4/03/313eda04aa08545a5a04ed5876e52f49ab76a4d98e54578896ca3e16313e/jiter-0.11.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a83097ce379e202dcc3fe3fc71a16d523d1ee9192c8e4e854158f96b3efe3f2f", size = 485897, upload-time = "2025-10-17T11:30:31.429Z" },
{ url = "https://files.pythonhosted.org/packages/5f/13/a1011b9d325e40b53b1b96a17c010b8646013417f3902f97a86325b19299/jiter-0.11.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7042c51e7fbeca65631eb0c332f90c0c082eab04334e7ccc28a8588e8e2804d9", size = 375224, upload-time = "2025-10-17T11:30:33.18Z" },
{ url = "https://files.pythonhosted.org/packages/92/da/1b45026b19dd39b419e917165ff0ea629dbb95f374a3a13d2df95e40a6ac/jiter-0.11.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a68d679c0e47649a61df591660507608adc2652442de7ec8276538ac46abe08", size = 356606, upload-time = "2025-10-17T11:30:34.572Z" },
{ url = "https://files.pythonhosted.org/packages/7a/0c/9acb0e54d6a8ba59ce923a180ebe824b4e00e80e56cefde86cc8e0a948be/jiter-0.11.1-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a1b0da75dbf4b6ec0b3c9e604d1ee8beaf15bc046fff7180f7d89e3cdbd3bb51", size = 384003, upload-time = "2025-10-17T11:30:35.987Z" },
{ url = "https://files.pythonhosted.org/packages/3f/2b/e5a5fe09d6da2145e4eed651e2ce37f3c0cf8016e48b1d302e21fb1628b7/jiter-0.11.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:69dd514bf0fa31c62147d6002e5ca2b3e7ef5894f5ac6f0a19752385f4e89437", size = 516946, upload-time = "2025-10-17T11:30:37.425Z" },
{ url = "https://files.pythonhosted.org/packages/5f/fe/db936e16e0228d48eb81f9934e8327e9fde5185e84f02174fcd22a01be87/jiter-0.11.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:bb31ac0b339efa24c0ca606febd8b77ef11c58d09af1b5f2be4c99e907b11111", size = 507614, upload-time = "2025-10-17T11:30:38.977Z" },
]
[[package]] [[package]]
name = "linkify-it-py" name = "linkify-it-py"
version = "2.0.3" version = "2.0.3"
@@ -969,25 +892,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" }, { url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
] ]
[[package]]
name = "openai"
version = "2.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jiter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/84/2c/3ca91dbd1a5b80c20fbd1e21d601f6afd7fd51927a1b27b08226b67ebd61/openai-2.7.0.tar.gz", hash = "sha256:8c42c24d06afece19e69afcb6c2b23b8b90f603a81616d8a0be80b80fb527ed2", size = 595876, upload-time = "2025-11-03T23:52:07.935Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fc/0f/e9618a92a9497846a3071f2a7ed43409215947106c7e5ce7d082f784de10/openai-2.7.0-py3-none-any.whl", hash = "sha256:9fc44861a692b7e80a7ec1252c10af79612a3ef1581ecb192caf4585afca5363", size = 1008759, upload-time = "2025-11-03T23:52:05.322Z" },
]
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "25.0" version = "25.0"
@@ -1213,6 +1117,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" },
] ]
[[package]]
name = "pytest-env"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/12/9c87d0ca45d5992473208bcef2828169fa7d39b8d7fc6e3401f5c08b8bf7/pytest_env-1.2.0.tar.gz", hash = "sha256:475e2ebe8626cee01f491f304a74b12137742397d6c784ea4bc258f069232b80", size = 8973, upload-time = "2025-10-09T19:15:47.42Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/98/822b924a4a3eb58aacba84444c7439fce32680592f394de26af9c76e2569/pytest_env-1.2.0-py3-none-any.whl", hash = "sha256:d7e5b7198f9b83c795377c09feefa45d56083834e60d04767efd64819fc9da00", size = 6251, upload-time = "2025-10-09T19:15:46.077Z" },
]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.3" version = "6.0.3"
@@ -1468,32 +1384,32 @@ dependencies = [
{ name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
] ]
-sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806 }
+sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" }
wheels = [
-{ url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802 },
+{ url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" },
-{ url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995 },
+{ url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" },
-{ url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948 },
+{ url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" },
-{ url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986 },
+{ url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" },
-{ url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222 },
+{ url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" },
-{ url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097 },
+{ url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" },
-{ url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309 },
+{ url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" },
-{ url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712 },
+{ url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" },
-{ url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725 },
+{ url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" },
-{ url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875 },
+{ url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" },
-{ url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451 },
+{ url = "https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" },
-{ url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794 },
+{ url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" },
-{ url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188 },
+{ url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" },
-{ url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978 },
+{ url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" },
-{ url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271 },
+{ url = "https://files.pythonhosted.org/packages/14/27/bf795595a2b897e271771cd31cb847d479073497344c637966bdf2853da1/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff", size = 1129271, upload-time = "2025-10-06T20:22:22.06Z" },
-{ url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216 },
+{ url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" },
-{ url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860 },
+{ url = "https://files.pythonhosted.org/packages/75/0d/881866647b8d1be4d67cb24e50d0c26f9f807f994aa1510cb9ba2fe5f612/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b", size = 1194860, upload-time = "2025-10-06T20:22:24.602Z" },
-{ url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567 },
+{ url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" },
-{ url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473 },
+{ url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" },
-{ url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855 },
+{ url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" },
-{ url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022 },
+{ url = "https://files.pythonhosted.org/packages/5f/77/4f268c41a3957c418b084dd576ea2fad2e95da0d8e1ab705372892c2ca22/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63", size = 1129022, upload-time = "2025-10-06T20:22:29.981Z" },
-{ url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736 },
+{ url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" },
-{ url = "https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908 },
+{ url = "https://files.pythonhosted.org/packages/28/c0/3c7a39ff68022ddfd7d93f3337ad90389a342f761c4d71de99a3ccc57857/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a", size = 1194908, upload-time = "2025-10-06T20:22:32.073Z" },
-{ url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706 },
+{ url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" },
]
[[package]]