Compare commits

...

7 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
12669cd59f size tests... 2026-01-30 19:27:40 +00:00
Ryuichi Leo Takashige
dc7c3bb5b4 send smaller chunks 2026-01-30 18:58:09 +00:00
Ryuichi Leo Takashige
ed59e1d0c5 dont use libp2p port oops 2026-01-30 18:55:32 +00:00
Ryuichi Leo Takashige
b28c0bbe96 Try profiling connections 2026-01-30 18:47:53 +00:00
Evan Quiney
cd946742f7 fix skipping logic in worker plan (#1342)
the worker plan function had some skipping logic missing, leading to
double-submitting tasks.
2026-01-30 14:31:40 +00:00
rltakashige
a5bc38ad1f Check all nodes to evict (#1341)
## Motivation

If nodes have uneven memory, one node may evict cache that remains on
another node. This will break prefill on some setups.

2026-01-30 13:42:09 +00:00
Evan Quiney
2a4e0d4629 make node-ids unique per-session (#1338)
we currently have no strict requirements that node ids persist across
sessions, so we can generate fresh node ids each time

this avoids issues like #1332, but prevents further features such as
caching downloads or node-id dialling

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-30 13:33:31 +00:00
16 changed files with 369 additions and 92 deletions

View File

@@ -40,6 +40,7 @@ class Node:
node_id: NodeId
event_index_counter: Iterator[int]
_profiling_in_progress: set[NodeId] = field(default_factory=set)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -86,6 +87,7 @@ class Node:
else:
api = None
profiling_in_progress: set[NodeId] = set()
if not args.no_worker:
worker = Worker(
node_id,
@@ -96,6 +98,7 @@ class Node:
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
profiling_in_progress=profiling_in_progress,
)
else:
worker = None
@@ -133,6 +136,7 @@ class Node:
api,
node_id,
event_index_counter,
profiling_in_progress,
)
async def run(self):
@@ -239,6 +243,7 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
profiling_in_progress=self._profiling_in_progress,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -8,8 +8,8 @@ from typing import Annotated, Literal, cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from anyio import BrokenResourceError, ClosedResourceError, EndOfStream, create_task_group
from anyio.abc import SocketStream, TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
@@ -1250,6 +1250,7 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._run_bandwidth_server)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1257,6 +1258,43 @@ class API:
shutdown_trigger=lambda: anyio.sleep_forever(),
)
async def _run_bandwidth_server(self):
"""TCP server for iperf-like bandwidth testing."""
bandwidth_port = self.port + 1
listener = await anyio.create_tcp_listener(local_port=bandwidth_port)
logger.info(f"Bandwidth test server listening on port {bandwidth_port}")
await listener.serve(self._handle_bandwidth_connection)
async def _handle_bandwidth_connection(self, stream: SocketStream) -> None:
"""Handle a single bandwidth test connection."""
try:
mode = await stream.receive(1)
if mode == b"U":
# Upload test: client sends, we receive
bytes_received = 0
start = time.perf_counter()
while True:
try:
data = await stream.receive(1024 * 1024)
if not data or data == b"DONE" or b"DONE" in data:
break
bytes_received += len(data)
except EndOfStream:
break
elapsed = time.perf_counter() - start
logger.debug(f"Bandwidth upload: {bytes_received} bytes in {elapsed:.3f}s")
elif mode == b"D":
# Download test: we send, client receives
chunk = b"X" * (1024 * 1024)
start = time.perf_counter()
while time.perf_counter() - start < 1.0: # Send for 1s
try:
await stream.send(chunk)
except (BrokenResourceError, ClosedResourceError):
break
except Exception as e:
logger.debug(f"Bandwidth connection error: {e}")
self.command_sender.close()
self.global_event_receiver.close()
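The server above speaks a tiny iperf-like protocol on API port + 1: the first byte selects the direction (`U` means the client uploads and the server counts bytes until it sees `DONE`; `D` means the server streams 1 MB chunks for roughly one second). A minimal client-side sketch of the upload half, assuming only anyio and the port convention shown here (the real client lives in the connection profiler further down); the host address is illustrative:

```python
import time
import anyio

async def quick_upload_test(host: str, api_port: int = 52415, seconds: float = 0.5) -> float:
    """Hedged sketch of the 'U' (upload) side of the protocol; returns Mbps."""
    chunk = b"X" * (1024 * 1024)
    async with await anyio.connect_tcp(host, api_port + 1) as stream:
        await stream.send(b"U")            # select upload mode
        sent, start = 0, time.perf_counter()
        while time.perf_counter() - start < seconds:
            await stream.send(chunk)
            sent += len(chunk)
        await stream.send(b"DONE")         # server stops counting at this marker
        elapsed = time.perf_counter() - start
    return (sent * 8 / elapsed) / 1_000_000

# anyio.run(quick_upload_test, "192.168.1.20")
```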

View File

@@ -28,6 +28,9 @@ def create_node_network() -> NodeNetworkInfo:
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
latency_ms=1.0,
other_to_sink_bandwidth_mbps=1000.0,
sink_to_other_bandwidth_mbps=1000.0,
)

View File

@@ -366,7 +366,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
ip_address="10.0.0.1",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000"),
latency_ms=1.0,
other_to_sink_bandwidth_mbps=1000.0,
sink_to_other_bandwidth_mbps=1000.0,
)
node_network = {

View File

@@ -15,6 +15,9 @@ def topology() -> Topology:
def socket_connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
latency_ms=1.0,
sink_to_other_bandwidth_mbps=1000.0,
other_to_sink_bandwidth_mbps=1000.0,
)

View File

@@ -216,6 +216,8 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

View File

@@ -16,6 +16,9 @@ def test_state_serialization_roundtrip() -> None:
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
latency_ms=1.5,
sink_to_other_bandwidth_mbps=1000.0,
other_to_sink_bandwidth_mbps=1100.0,
),
)

View File

@@ -24,10 +24,22 @@ class RDMAConnection(FrozenModel):
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
latency_ms: float
sink_to_other_bandwidth_mbps: float
other_to_sink_bandwidth_mbps: float
@property
def bandwidth_mbps(self):
return min(self.sink_to_other_bandwidth_mbps, self.other_to_sink_bandwidth_mbps)
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SocketConnection):
return NotImplemented
return self.sink_multiaddr == other.sink_multiaddr
class Connection(FrozenModel):
source: NodeId
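With the new fields, an edge's effective bandwidth is the bottleneck of its two directions, and equality/hashing key on the sink multiaddr rather than the measured values, so a re-discovered edge to the same sink dedupes cleanly. A hedged usage sketch, using the `Multiaddr` constructor from the test fixtures above (the `Multiaddr` import path is an assumption):

```python
from exo.shared.types.topology import SocketConnection
from exo.shared.types.multiaddr import Multiaddr  # import path is an assumption

edge = SocketConnection(
    sink_multiaddr=Multiaddr(address="/ip4/192.168.1.20/tcp/52415"),
    latency_ms=1.2,
    sink_to_other_bandwidth_mbps=940.0,
    other_to_sink_bandwidth_mbps=1100.0,
)
assert edge.bandwidth_mbps == 940.0  # min of the two directions

# A placeholder edge to the same sink (e.g. from ping discovery, before profiling)
# compares and hashes equal, so sets collapse the two entries.
placeholder = SocketConnection(
    sink_multiaddr=Multiaddr(address="/ip4/192.168.1.20/tcp/52415"),
    latency_ms=0.0,
    sink_to_other_bandwidth_mbps=0.0,
    other_to_sink_bandwidth_mbps=0.0,
)
assert edge == placeholder and len({edge, placeholder}) == 1
```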

View File

@@ -0,0 +1,125 @@
import time
import anyio
import httpx
from anyio.abc import SocketStream
from exo.shared.logging import logger
from pydantic.v1 import BaseModel
LATENCY_PING_COUNT = 5
BANDWIDTH_TEST_DURATION_S = 0.5
BANDWIDTH_TEST_PORT_OFFSET = 1 # API port + 1
class ConnectionProfile(BaseModel):
latency_ms: float
upload_mbps: float
download_mbps: float
async def measure_latency(target_ip: str, port: int = 52415) -> float:
if ":" in target_ip:
url = f"http://[{target_ip}]:{port}/node_id"
else:
url = f"http://{target_ip}:{port}/node_id"
rtts: list[float] = []
async with httpx.AsyncClient(timeout=10.0) as client:
for _ in range(LATENCY_PING_COUNT):
try:
start = time.perf_counter()
response = await client.get(url)
end = time.perf_counter()
if response.status_code == 200:
rtts.append((end - start) * 1000)
except (httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError) as e:
logger.debug(f"Latency ping failed: {e}")
if not rtts:
raise ConnectionError(f"Failed to measure latency to {target_ip}:{port}")
return sum(rtts) / len(rtts)
async def _measure_upload_tcp(stream: SocketStream, duration: float) -> float:
"""Send data for duration seconds, return Mbps."""
chunk = b"X" * (1024 * 1024) # 1MB
bytes_sent = 0
start = time.perf_counter()
deadline = start + duration
while time.perf_counter() < deadline:
await stream.send(chunk)
bytes_sent += len(chunk)
elapsed = time.perf_counter() - start
return (bytes_sent * 8 / elapsed) / 1_000_000 if elapsed > 0 else 0.0
async def _measure_download_tcp(stream: SocketStream, duration: float) -> float:
"""Receive data for duration seconds, return Mbps."""
bytes_received = 0
start = time.perf_counter()
with anyio.move_on_after(duration):
while True:
data = await stream.receive(1024 * 1024)
if not data:
break
bytes_received += len(data)
elapsed = time.perf_counter() - start
return (bytes_received * 8 / elapsed) / 1_000_000 if elapsed > 0 else 0.0
async def measure_bandwidth_tcp(target_ip: str, port: int) -> tuple[float, float]:
"""Measure bandwidth using raw TCP like iperf."""
upload_mbps = 0.0
download_mbps = 0.0
try:
async with await anyio.connect_tcp(target_ip, port) as stream:
# Protocol: send 'U' for upload test, 'D' for download test
# Upload: client sends, server receives
await stream.send(b"U")
upload_mbps = await _measure_upload_tcp(stream, BANDWIDTH_TEST_DURATION_S)
await stream.send(b"DONE")
logger.debug(f"Upload: {upload_mbps:.1f} Mbps")
except Exception as e:
logger.debug(f"Upload TCP test failed: {e}")
try:
async with await anyio.connect_tcp(target_ip, port) as stream:
# Download: client receives, server sends
await stream.send(b"D")
download_mbps = await _measure_download_tcp(stream, BANDWIDTH_TEST_DURATION_S)
logger.debug(f"Download: {download_mbps:.1f} Mbps")
except Exception as e:
logger.debug(f"Download TCP test failed: {e}")
return upload_mbps, download_mbps
async def profile_connection(target_ip: str, port: int = 52415) -> ConnectionProfile:
logger.debug(f"Profiling connection to {target_ip}:{port}")
latency_ms = await measure_latency(target_ip, port)
logger.debug(f"Measured latency to {target_ip}: {latency_ms:.2f}ms")
bandwidth_port = port + BANDWIDTH_TEST_PORT_OFFSET
upload_mbps, download_mbps = await measure_bandwidth_tcp(target_ip, bandwidth_port)
logger.debug(
f"Measured bandwidth to {target_ip}: "
f"upload={upload_mbps:.1f}Mbps, download={download_mbps:.1f}Mbps"
)
if upload_mbps == 0.0 and download_mbps == 0.0:
raise ConnectionError(f"Failed to measure bandwidth to {target_ip}:{bandwidth_port}")
return ConnectionProfile(
latency_ms=latency_ms,
upload_mbps=upload_mbps,
download_mbps=download_mbps,
)
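A hedged usage sketch of the profiler above: the worker calls it with just the peer's IP (default API port 52415) and treats `ConnectionError` as "reachable but unmeasurable". The peer address here is illustrative:

```python
import anyio
from exo.utils.info_gatherer.connection_profiler import profile_connection

async def main() -> None:
    try:
        profile = await profile_connection("192.168.1.20")  # hypothetical peer
    except ConnectionError as e:
        print(f"profiling failed: {e}")
        return
    print(
        f"latency={profile.latency_ms:.2f}ms "
        f"up={profile.upload_mbps:.1f}Mbps down={profile.download_mbps:.1f}Mbps"
    )

anyio.run(main)
```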

View File

@@ -3,6 +3,7 @@ from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
@@ -12,25 +13,29 @@ from mlx_lm.models.cache import (
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.85
_DEFAULT_MEMORY_THRESHOLD = 0.9
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(self, tokenizer: TokenizerWrapper):
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
@@ -81,13 +86,13 @@ class KVPrefixCache:
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -109,7 +114,7 @@ class KVPrefixCache:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -131,29 +136,37 @@ class KVPrefixCache:
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while len(self.caches) > 0:
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -168,13 +181,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
@@ -185,6 +198,17 @@ def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:
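The eviction check now uses the maximum memory utilisation across all ranks in the MLX distributed group, so a node with plenty of headroom still evicts when a peer that holds the same cache is under pressure (the uneven-memory scenario from #1341). A stripped-down sketch of just that reduction, assuming an already-initialised `mx.distributed` group:

```python
import mlx.core as mx
import psutil

def cluster_memory_used_fraction(group: mx.distributed.Group | None) -> float:
    """Worst-case (max) memory-used fraction across the group, 0.0-1.0."""
    local = psutil.virtual_memory().percent / 100
    if group is None:
        return local
    gathered = mx.distributed.all_gather(
        mx.array([local], dtype=mx.float32), group=group
    )
    return float(mx.max(gathered).item())  # .item() forces evaluation

# Eviction then keys off the cluster-wide value, e.g.:
# if cluster_memory_used_fraction(group) > 0.9: evict_lru_entry()
```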

View File

@@ -44,6 +44,7 @@ from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.connection_profiler import profile_connection
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
@@ -65,6 +66,7 @@ class Worker:
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
profiling_in_progress: set[NodeId] | None = None,
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -81,6 +83,8 @@ class Worker:
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._profiling_in_progress: set[NodeId] = profiling_in_progress if profiling_in_progress is not None else set()
self._profiling_lock = anyio.Lock()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -282,37 +286,73 @@ class Worker:
async def _connection_message_event_writer(self):
with self.connection_message_receiver as connection_messages:
async for msg in connection_messages:
await self.event_sender.send(
self._convert_connection_message_to_event(msg)
)
event = await self._convert_connection_message_to_event(msg)
if event:
await self.event_sender.send(event)
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
),
)
async def _convert_connection_message_to_event(self, msg: ConnectionMessage):
if msg.connection_type == ConnectionMessageType.Connected:
async with self._profiling_lock:
if msg.node_id in self._profiling_in_progress:
return None
self._profiling_in_progress.add(msg.node_id)
return await self._profile_and_emit_connection(
msg.node_id,
msg.remote_ipv4,
msg.remote_tcp_port,
)
elif msg.connection_type == ConnectionMessageType.Disconnected:
self._profiling_in_progress.discard(msg.node_id)
target_ip = msg.remote_ipv4
for connection in self.state.topology.list_connections():
if (
isinstance(connection.edge, SocketConnection)
and connection.edge.sink_multiaddr.ip_address == target_ip
):
return TopologyEdgeDeleted(conn=connection)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
async def _profile_and_emit_connection(
self,
sink_node_id: NodeId,
remote_ip: str,
remote_port: int,
):
try:
profile = await profile_connection(remote_ip)
except ConnectionError as e:
logger.warning(f"Failed to profile connection to {sink_node_id}: {e}")
profile = None
if profile:
latency_ms = profile.latency_ms
other_to_sink_mbps = profile.upload_mbps
sink_to_other_mbps = profile.download_mbps
else:
latency_ms = 0.0
other_to_sink_mbps = 0.0
sink_to_other_mbps = 0.0
logger.info(
f"Connection to {sink_node_id} profiled: "
f"latency={latency_ms:.2f}ms, "
f"upload={other_to_sink_mbps:.1f}Mbps, "
f"download={sink_to_other_mbps:.1f}Mbps"
)
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=sink_node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{remote_ip}/tcp/{remote_port}"
),
)
latency_ms=latency_ms,
sink_to_other_bandwidth_mbps=sink_to_other_mbps,
other_to_sink_bandwidth_mbps=other_to_sink_mbps,
),
),
)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
@@ -389,21 +429,21 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
edge = SocketConnection(
# nonsense multiaddr
temp_edge = SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
latency_ms=0.0,
sink_to_other_bandwidth_mbps=0.0,
other_to_sink_bandwidth_mbps=0.0,
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
conn=Connection(
source=self.node_id, sink=nid, edge=edge
)
)
if temp_edge not in edges:
logger.debug(f"ping discovered new connection to {nid} at {ip}")
self._tg.start_soon(
self._profile_and_emit_connection,
nid,
ip,
52415,
)
for conn in self.state.topology.out_edges(self.node_id):
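Because the same peer can surface from both the connection-message stream and ping discovery, the worker shares a `_profiling_in_progress` set (handed down from `Node`) and takes an `anyio.Lock` before checking it, so each peer is profiled at most once at a time. A minimal sketch of that guard pattern with hypothetical names (in the worker itself, the entry is discarded on the Disconnected message rather than in a `finally`):

```python
import anyio

class ProfileGuard:
    """Run profile_fn at most once concurrently per peer id."""

    def __init__(self) -> None:
        self._in_progress: set[str] = set()
        self._lock = anyio.Lock()

    async def run_once(self, peer_id: str, profile_fn) -> None:
        async with self._lock:
            if peer_id in self._in_progress:
                return  # another task is already profiling this peer
            self._in_progress.add(peer_id)
        try:
            await profile_fn(peer_id)
        finally:
            self._in_progress.discard(peer_id)
```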

View File

@@ -37,6 +37,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -111,8 +112,12 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
seen = set[TaskId]()
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -163,7 +168,7 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks

View File

@@ -127,20 +127,25 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
f"Skipping invalid task {task} as it has already been submitted"
)
return
if task.task_id in self.completed:
logger.info(
logger.warning(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
try:
self._task_sender.send(task)
await self._task_sender.send_async(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:

View File

@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
cache_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
assert get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
assert get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
class TestKVPrefix:
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
assert cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
expected_prefix = get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
first_cache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
firstcache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -147,6 +147,14 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,