Compare commits

..

1 Commit

Author: Evan
SHA1: 2ee0bce898

state compaction
Introduces a new topic ("state_catchup") over which a full state can be
sent. Currently the master sends this new state to the worker and the API,
and they adopt it only if they have no other events applied; otherwise the
usual NACK system functions.
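
A minimal, self-contained sketch of the adopt-or-NACK decision described above (hypothetical helper and type names for illustration; the real logic lives in the Master, Worker, and API diffs below):

```python
from dataclasses import dataclass, field


@dataclass
class State:
    last_event_applied_idx: int = -1  # -1 means no events applied yet


@dataclass
class OrderedBuffer:
    store: dict = field(default_factory=dict)
    next_idx_to_release: int = 0


def maybe_adopt_snapshot(local: State, buffer: OrderedBuffer, snapshot: State) -> State:
    # Adopt the full-state snapshot only if no events have been applied yet;
    # otherwise keep the local state and let the usual NACK replay catch up.
    if (
        local.last_event_applied_idx == -1
        and snapshot.last_event_applied_idx > local.last_event_applied_idx
    ):
        buffer.store = {}  # drop any buffered out-of-order events
        buffer.next_idx_to_release = snapshot.last_event_applied_idx + 1
        return snapshot
    return local


# A fresh node (idx == -1) jumps straight to the snapshot.
assert maybe_adopt_snapshot(State(), OrderedBuffer(), State(41)).last_event_applied_idx == 41
```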

## testing

Manually tested on two nodes.
Date: 2026-01-23 14:15:06 +00:00
13 changed files with 300 additions and 354 deletions

View File

@@ -2,14 +2,11 @@
This type stub file was generated by pyright.
"""
from typing import Any, Dict, List, Optional, Protocol, Literal, Self
import mlx.nn as nn
import mlx.core as mx
from typing import Any, Dict, List, Literal, Optional, Protocol, Self
from mlx.core import array
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class Cache(Protocol):
keys: mx.array
@@ -35,7 +32,6 @@ def make_prompt_cache(
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
...
def save_prompt_cache(
file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
@@ -49,7 +45,6 @@ def save_prompt_cache(
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
...
def load_prompt_cache(file_name: str, return_metadata=...) -> array:
"""
@@ -64,15 +59,13 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
...
def can_trim_prompt_cache(cache: List[Any]) -> bool:
def can_trim_prompt_cache(cache: List[Cache]) -> bool:
"""
Check if model's cache can be trimmed.
"""
...
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
"""
Trim the model's cache by the given number of tokens.
@@ -86,7 +79,6 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
Returns:
(int): The number of tokens that were trimmed.
"""
...
def create_attention_mask(
N: int, offset: int, return_array: bool, window_size: Optional[int]
@@ -115,125 +107,164 @@ class ConcatenateKVCache(_BaseCache):
KVCache with a larger step size before using this cache.
"""
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): ...
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class QuantizedKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class KVCache(_BaseCache):
step = ...
offset: int
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(self) -> tuple[array, array]: ...
def state(
self,
) -> tuple[array, array]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): ...
def trim(self, n): ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(self, *args, **kwargs): ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class RotatingKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, max_size, keep=...) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): ...
): # -> array | Literal['causal'] | None:
...
class ArraysCache(_BaseCache):
def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
def __setitem__(self, idx, value): ...
def __setitem__(self, idx, value): # -> None:
...
def __getitem__(self, idx): ...
@property
def state(self): ...
def state(self): # -> list[Any | array] | list[array]:
...
@state.setter
def state(self, v): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
def make_mask(self, N: int): ...
def make_mask(self, N: int): # -> array | None:
...
class MambaCache(ArraysCache):
def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...
class ChunkedKVCache(KVCache):
def __init__(self, chunk_size) -> None: ...
def maybe_trim_front(self): ...
def update_and_fetch(self, keys, values): ...
def trim(self, n): ...
def maybe_trim_front(self): # -> None:
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def meta_state(self, v): # -> None:
...
class CacheList(_BaseCache):
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
@property
def state(self): ...
def state(self): # -> list[Any]:
...
@state.setter
def state(self, v): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
class BatchKVCache(_BaseCache):
step = ...
@@ -256,56 +287,71 @@ class BatchKVCache(_BaseCache):
And ``left_padding`` specifies the amount of padding for each.
In this case, ``left_padding = [1, 3, 0]``.
"""
...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int | float:
...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
class BatchRotatingKVCache(_BaseCache):
step = ...
def __init__(self, max_size, left_padding: List[int]) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): ...
def filter(self, batch_indices):
): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...

View File

@@ -8,10 +8,6 @@ from typing import Any
from transformers import PreTrainedTokenizerFast
"""
This type stub file was generated by pyright.
"""
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -49,7 +45,6 @@ class StreamingDetokenizer:
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
...
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -59,11 +54,15 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@property
def text(self): ...
def text(self): # -> str:
...
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
@@ -72,9 +71,12 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
underscore which results in linear complexity.
"""
def __init__(self, tokenizer, trim_space=...) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
@@ -86,13 +88,15 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
_byte_decoder = ...
_space_matches = ...
def __init__(self, tokenizer) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@classmethod
def make_byte_decoder(cls):
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@classmethod
def make_byte_decoder(cls): # -> None:
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.
@@ -153,10 +157,13 @@ class TokenizerWrapper:
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
def __init__(self, *args, **kwargs) -> None: ...
def encode(self, text, **kwargs): ...
def encode(self, text, **kwargs): # -> list[int]:
...
def encode_batch(self, texts, **kwargs): ...
def decode(self, *args, **kwargs): ...
def batch_decode(self, *args, **kwargs): ...
def decode(self, *args, **kwargs): # -> str:
...
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load(
model_path: Path,
@@ -169,7 +176,6 @@ def load(
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
...
# Alias for backward compatibility
load_tokenizer = load

View File

@@ -3,45 +3,6 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -81,12 +42,11 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -49,6 +49,7 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.STATE_CATCHUP)
logger.info(f"Starting node {node_id}")
if args.spawn_api:
@@ -59,6 +60,7 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
api = None
@@ -72,6 +74,7 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
worker = None
@@ -83,6 +86,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
)
er_send, er_recv = channel[ElectionResult]()
@@ -153,6 +157,7 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
)
self._tg.start_soon(self.master.run)
elif (
@@ -185,6 +190,9 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
state_catchup_receiver=self.router.receiver(
topics.STATE_CATCHUP
),
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -158,12 +158,14 @@ class API:
command_sender: Sender[ForwarderCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
state_catchup_receiver: Receiver[State],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.state_catchup_receiver = state_catchup_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -1231,6 +1233,7 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._state_catchup)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1241,6 +1244,22 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _state_catchup(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"API catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -68,6 +68,8 @@ class Master:
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
# not a fan but - send the entire state to a node so it can catchup without the whole event log.
state_catchup_sender: Sender[State],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -77,6 +79,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.state_catchup_sender = state_catchup_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -84,7 +87,6 @@ class Master:
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
async def run(self):
@@ -291,11 +293,17 @@ class Master:
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
if command.since_idx == 0:
# This is an optimization, and should not be relied upon in theory.
logger.info(
f"Master sending catchup state for index {self.state.last_event_applied_idx}"
)
await self.state_catchup_sender.send(self.state)
else:
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
st_s, _st_r = channel[State]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
state_catchup_sender=st_s,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -7,6 +7,7 @@ from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.shared.types.state import State
from exo.utils.pydantic_ext import CamelCaseModel
@@ -45,3 +46,4 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)

View File

@@ -1,11 +0,0 @@
"""Shared types for MLX-related functionality."""
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]

View File

@@ -1,53 +1,39 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable
import mlx.core as mx
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self.caches: list[list[_BaseCache]] = []
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = encode_prompt(tokenizer, prompt)
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
logger.info(f"KV cache saved: {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> tuple[KVCacheType, mx.array]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
This method finds the best matching cached prefix and returns:
- A copy of the cache trimmed to the prefix length
- The remaining tokens that need to be prefilled before generation
The caller is responsible for prefilling the remaining tokens.
Returns:
Tuple of (cache, remaining_tokens) where remaining_tokens are the
tokens that still need to be prefilled/processed.
"""
tokenized_prompt = encode_prompt(tokenizer, prompt)
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
@@ -56,75 +42,63 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:]
return self.caches[i]
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)
# Return remaining tokens for caller to prefill
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
)
# Return all tokens for caller to prefill
return prompt_cache, tokenized_prompt
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
return prompt_cache
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
if n == 0:
return 0
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass

View File

@@ -1,12 +1,12 @@
import time
from typing import Callable, Generator, cast, get_args
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,13 +14,11 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -32,59 +30,19 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cache, 1)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def warmup_inference(
@@ -162,7 +120,6 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -174,16 +131,7 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -196,19 +144,11 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[-1:], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -218,13 +158,12 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -246,22 +185,6 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
# Save cache for future prefix matching (clear first to keep only the last one)
if kv_prefix_cache is not None:
kv_prefix_cache.clear()
full_prompt = prompt + "".join(generated_text_parts)
kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -67,9 +67,8 @@ class Worker:
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
state_catchup_receiver: Receiver[State],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -79,6 +78,7 @@ class Worker:
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.state_catchup_receiver = state_catchup_receiver
self.local_event_index = 0
self.command_sender = command_sender
self.connection_message_receiver = connection_message_receiver
@@ -117,6 +117,7 @@ class Worker:
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._check_catchup_state)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
@@ -135,6 +136,22 @@ class Worker:
)
)
async def _check_catchup_state(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"Worker catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -342,10 +359,7 @@ class Worker:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
assert since_idx >= 0
with CancelScope() as scope:
self._nack_cancel_scope = scope

View File

@@ -70,7 +70,6 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -104,7 +103,6 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -173,9 +171,6 @@ def main(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
kv_prefix_cache = KVPrefixCache()
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -243,7 +238,6 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# GPT-OSS specific parsing to match other model formats.