Mirror of https://github.com/exo-explore/exo.git
Synced 2026-01-23 13:29:29 -05:00

Compare commits: state-comp...fix-kv-pre (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 7744420341 | |
| | b777c6f505 | |
| | 812a9f232e | |
| | f255345a1a | |
@@ -2,11 +2,14 @@
This type stub file was generated by pyright.
"""

from typing import Any, Dict, List, Optional, Protocol, Literal, Self

import mlx.nn as nn
from mlx.core import array
import mlx.core as mx
from typing import Any, Dict, List, Literal, Optional, Protocol, Self
from mlx.core import array

"""
This type stub file was generated by pyright.
"""

class Cache(Protocol):
    keys: mx.array
@@ -32,6 +35,7 @@ def make_prompt_cache(
    ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
    size of ``max_kv_size``
    """
    ...

def save_prompt_cache(
    file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
@@ -45,6 +49,7 @@ def save_prompt_cache(
        metadata (Dict[str, str]): Optional metadata to save along with model
            state.
    """
    ...

def load_prompt_cache(file_name: str, return_metadata=...) -> array:
    """
@@ -59,13 +64,15 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
            the metadata if requested.
    """
    ...

def can_trim_prompt_cache(cache: List[Cache]) -> bool:
def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if model's cache can be trimmed.
    """
    ...

def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
    """
    Trim the model's cache by the given number of tokens.
@@ -79,6 +86,7 @@ def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
    Returns:
        (int): The number of tokens that were trimmed.
    """
    ...

def create_attention_mask(
    N: int, offset: int, return_array: bool, window_size: Optional[int]
@@ -107,164 +115,125 @@ class ConcatenateKVCache(_BaseCache):
    KVCache with a larger step size before using this cache.
    """
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(self): # -> tuple[Any | array | None, Any | array | None]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...
    def state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, *args, **kwargs): ...

class QuantizedKVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
    def update_and_fetch(self, keys, values): # -> Any:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(
        self,
    ): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def state(self, v): ...
    @property
    def meta_state(self): # -> tuple[str, ...]:
        ...
    def meta_state(self): ...
    @meta_state.setter
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, *args, **kwargs): ...

class KVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self) -> None: ...
    def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(
        self,
    ) -> tuple[array, array]: ...
    def state(self) -> tuple[array, array]: ...
    @state.setter
    def state(self, v) -> None: ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int:
        ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
        ...
    def make_mask(self, *args, **kwargs): ...

class RotatingKVCache(_BaseCache):
    step = ...
    offset: int
    def __init__(self, max_size, keep=...) -> None: ...
    def update_and_fetch(
        self, keys, values
    ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(
        self,
    ): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def state(self, v): ...
    @property
    def meta_state(self): # -> tuple[str, ...]:
        ...
    def meta_state(self): ...
    @meta_state.setter
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> bool:
        ...
    def trim(self, n): # -> int:
        ...
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ): # -> array | Literal['causal'] | None:
        ...
    ): ...

class ArraysCache(_BaseCache):
    def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
    def __setitem__(self, idx, value): # -> None:
        ...
    def __setitem__(self, idx, value): ...
    def __getitem__(self, idx): ...
    @property
    def state(self): # -> list[Any | array] | list[array]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def filter(self, batch_indices): # -> None:
    def state(self, v): ...
    def filter(self, batch_indices):
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other): # -> None:
    def extend(self, other):
        """
        In-place extend this cache with the other cache.
        """

    def make_mask(self, N: int): # -> array | None:
        ...

    def make_mask(self, N: int): ...

class MambaCache(ArraysCache):
    def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...

class ChunkedKVCache(KVCache):
    def __init__(self, chunk_size) -> None: ...
    def maybe_trim_front(self): # -> None:
        ...
    def update_and_fetch(self, keys, values): # -> tuple[array, array]:
        ...
    def trim(self, n): # -> int:
        ...
    def maybe_trim_front(self): ...
    def update_and_fetch(self, keys, values): ...
    def trim(self, n): ...
    @property
    def meta_state(self): # -> tuple[str, ...]:
        ...
    def meta_state(self): ...
    @meta_state.setter
    def meta_state(self, v): # -> None:
        ...
    def meta_state(self, v): ...

class CacheList(_BaseCache):
    def __init__(self, *caches) -> None: ...
    def __getitem__(self, idx): ...
    def is_trimmable(self): # -> bool:
        ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    @property
    def state(self): # -> list[Any]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def filter(self, batch_indices): # -> None:
    def state(self, v): ...
    def filter(self, batch_indices):
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other): # -> None:
    def extend(self, other):
        """
        In-place extend this cache with the other cache.
        """
        ...

class BatchKVCache(_BaseCache):
    step = ...
@@ -287,71 +256,56 @@ class BatchKVCache(_BaseCache):
        And ``left_padding`` specifies the amount of padding for each.
        In this case, ``left_padding = [1, 3, 0]``.
        """
        ...

    def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(
        self,
    ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> Literal[True]:
        ...
    def trim(self, n): # -> int | float:
        ...
    def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
        ...
    def filter(self, batch_indices): # -> None:
    def state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def make_mask(self, N: int, return_array: bool = ..., **kwargs): ...
    def filter(self, batch_indices):
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other): # -> None:
    def extend(self, other):
        """
        In-place extend this cache with the other cache.
        """
        ...

class BatchRotatingKVCache(_BaseCache):
    step = ...
    def __init__(self, max_size, left_padding: List[int]) -> None: ...
    def update_and_fetch(
        self, keys, values
    ): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
        ...
    def update_and_fetch(self, keys, values): ...
    @property
    def state(
        self,
    ): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
        ...
    def state(self): ...
    @state.setter
    def state(self, v): # -> None:
        ...
    def state(self, v): ...
    @property
    def meta_state(self): # -> tuple[str, ...]:
        ...
    def meta_state(self): ...
    @meta_state.setter
    def meta_state(self, v): # -> None:
        ...
    def is_trimmable(self): # -> bool:
        ...
    def trim(self, n): # -> int:
        ...
    def meta_state(self, v): ...
    def is_trimmable(self): ...
    def trim(self, n): ...
    def to_quantized(
        self, group_size: int = ..., bits: int = ...
    ) -> QuantizedKVCache: ...
    def make_mask(
        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
    ): # -> array:
        ...
    def filter(self, batch_indices): # -> None:
    ): ...
    def filter(self, batch_indices):
        """
        In-place filter to keep just the given indices in the cache.
        """
        ...

    def extend(self, other): # -> None:
    def extend(self, other):
        """
        In-place extend this cache with the other cache.
        """
        ...
@@ -8,6 +8,10 @@ from typing import Any

from transformers import PreTrainedTokenizerFast

"""
This type stub file was generated by pyright.
"""

class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -45,6 +49,7 @@ class StreamingDetokenizer:
    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        ...

class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -54,15 +59,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
    repeatedly detokenize the same tokens until a new line is generated.
    """
    def __init__(self, tokenizer) -> None: ...
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    @property
    def text(self): # -> str:
        ...
    def text(self): ...

class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.
@@ -71,12 +72,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
    underscore which results in linear complexity.
    """
    def __init__(self, tokenizer, trim_space=...) -> None: ...
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...

class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.
@@ -88,15 +86,13 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
    _byte_decoder = ...
    _space_matches = ...
    def __init__(self, tokenizer) -> None: ...
    def reset(self): # -> None:
        ...
    def add_token(self, token): # -> None:
        ...
    def finalize(self): # -> None:
        ...
    def reset(self): ...
    def add_token(self, token): ...
    def finalize(self): ...
    @classmethod
    def make_byte_decoder(cls): # -> None:
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
        ...

class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.
@@ -157,13 +153,10 @@ class TokenizerWrapper:
class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
    def __init__(self, *args, **kwargs) -> None: ...
    def encode(self, text, **kwargs): # -> list[int]:
        ...
    def encode(self, text, **kwargs): ...
    def encode_batch(self, texts, **kwargs): ...
    def decode(self, *args, **kwargs): # -> str:
        ...
    def batch_decode(self, *args, **kwargs): # -> list[str]:
        ...
    def decode(self, *args, **kwargs): ...
    def batch_decode(self, *args, **kwargs): ...

def load(
    model_path: Path,
@@ -176,6 +169,7 @@ def load(
    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    ...

# Alias for backward compatibility
load_tokenizer = load
@@ -3,6 +3,45 @@
  perSystem =
    { pkgs, lib, ... }:
    let
      # Stub source with lockfiles and minimal files for build to succeed
      # This allows prettier-svelte to avoid rebuilding when dashboard source changes
      dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
        mkdir -p $out
        cp ${inputs.self}/dashboard/package.json $out/
        cp ${inputs.self}/dashboard/package-lock.json $out/
        # Minimal files so vite build succeeds (produces empty output)
        echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
        mkdir -p $out/src
        touch $out/src/app.html
      '';

      # Deps-only build using stub source (for prettier-svelte)
      # Only rebuilds when package.json or package-lock.json change
      dashboardDeps = inputs.dream2nix.lib.evalModules {
        packageSets.nixpkgs = pkgs;
        modules = [
          ./dashboard.nix
          {
            paths.projectRoot = inputs.self;
            paths.projectRootFile = "flake.nix";
            paths.package = inputs.self + "/dashboard";
          }
          {
            deps.dashboardSrc = lib.mkForce dashboardStubSrc;
          }
          # Override build phases to skip the actual build - just need node_modules
          {
            mkDerivation = {
              buildPhase = lib.mkForce "true";
              installPhase = lib.mkForce ''
                runHook preInstall
                runHook postInstall
              '';
            };
          }
        ];
      };

      # Filter source to only include dashboard directory
      dashboardSrc = lib.cleanSourceWith {
        src = inputs.self;
@@ -42,11 +81,12 @@
      '';

      # Prettier with svelte plugin for treefmt
      # Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
      packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
        export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
        export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
        exec ${pkgs.nodejs}/bin/node \
          ${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
          --plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
          ${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
          --plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
          "$@"
      '';
    };
src/exo/shared/types/mlx.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Shared types for MLX-related functionality."""

from mlx_lm.models.cache import (
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]
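For orientation, here is a minimal sketch (not part of the diff) of how a cache list typed as `KVCacheType` can be inspected. `describe_cache` is a hypothetical helper; the `.offset` attribute it reads is the same per-layer token counter used by `_cache_length` later in this changeset.

```python
from exo.shared.types.mlx import KVCacheType


def describe_cache(cache: KVCacheType) -> str:
    # One entry per transformer layer; `offset` tracks how many tokens
    # each layer currently holds.
    n_layers = len(cache)
    n_tokens = max((c.offset for c in cache), default=0)
    return f"{n_layers} layers, {n_tokens} tokens cached"
```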
@@ -1,39 +1,53 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable

import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger


class KVPrefixCache:
    def __init__(self):
        # Only one prefix cache per runner.
        self.prompts: list[mx.array] = []  # mx array of tokens (ints)
        self.caches: list[list[_BaseCache]] = []
        self.caches: list[KVCacheType] = []

    def clear(self):
        """Clear all cached prompts and caches."""
        self.prompts.clear()
        self.caches.clear()

    def add_kv_cache(
        self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
        self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
    ):
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
        tokenized_prompt = encode_prompt(tokenizer, prompt)
        self.prompts.append(tokenized_prompt)
        self.caches.append(deepcopy(cache))
        logger.info(f"KV cache saved: {len(tokenized_prompt)} tokens")

    def get_kv_cache(
        self,
        model: Model,
        tokenizer: TokenizerWrapper,
        sampler: Callable[[mx.array], mx.array],
        prompt: str,
    ) -> list[_BaseCache]:
        tokenized_prompt = self.encode_prompt(tokenizer, prompt)
    ) -> tuple[KVCacheType, mx.array]:
        """Get KV cache for prompt, returning remaining tokens to prefill.

        This method finds the best matching cached prefix and returns:
        - A copy of the cache trimmed to the prefix length
        - The remaining tokens that need to be prefilled before generation

        The caller is responsible for prefilling the remaining tokens.

        Returns:
            Tuple of (cache, remaining_tokens) where remaining_tokens are the
            tokens that still need to be prefilled/processed.
        """
        tokenized_prompt = encode_prompt(tokenizer, prompt)
        max_length = len(tokenized_prompt)

        best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +56,75 @@ class KVPrefixCache:
            length = _get_prefix_length(tokenized_prompt, cached_prompt)

            if length == max_length:
                return self.caches[i]
                # Exact match - cached prompt starts with our entire prompt
                # Trim cache to prompt length - 1, return last token for stream_generate
                prompt_cache = deepcopy(self.caches[i])
                cached_length = _cache_length(self.caches[i])
                tokens_to_trim = cached_length - (max_length - 1)
                if tokens_to_trim > 0:
                    trim_prompt_cache(prompt_cache, tokens_to_trim)
                logger.info(f"KV cache exact match: {max_length} tokens (instant)")
                return prompt_cache, tokenized_prompt[-1:]

            if length > best_snapshot_length:
                best_snapshot_index, best_snapshot_length = i, length

        if best_snapshot_index is not None:
            prompt_cache = deepcopy(self.caches[best_snapshot_index])
            trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
            tokenized_prompt = tokenized_prompt[best_snapshot_index:]

        else:
            prompt_cache = make_kv_cache(
                model,
                # max_kv_size=MAX_KV_SIZE,
                # keep=KEEP_KV_SIZE
            new_tokens = max_length - best_snapshot_length
            logger.info(
                f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
                f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
            )

        prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
            prompt_cache = deepcopy(self.caches[best_snapshot_index])

        return prompt_cache
            # Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
            cached_length = _cache_length(self.caches[best_snapshot_index])
            tokens_to_trim = cached_length - best_snapshot_length
            if tokens_to_trim > 0:
                trim_prompt_cache(prompt_cache, tokens_to_trim)

    def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
            tokenizer.bos_token
        )
        tokenized_prompt = tokenizer.encode(
            prompt, add_special_tokens=add_special_tokens
        )
        return mx.array(tokenized_prompt)
            # Return remaining tokens for caller to prefill
            remaining_tokens = tokenized_prompt[best_snapshot_length:]
            return prompt_cache, remaining_tokens

        else:
            prompt_cache = make_kv_cache(model)
            if len(self.prompts) == 0:
                logger.info(f"KV cache empty, need to prefill {max_length} tokens")
            else:
                logger.info(
                    f"KV cache no prefix match, need to prefill {max_length} tokens"
                )

            # Return all tokens for caller to prefill
            return prompt_cache, tokenized_prompt


def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
    """Encode a prompt string to token array.

    For chat-templated prompts (which have their own structure markers like
    <|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
    that would corrupt the prompt structure.
    """
    # Chat templates define their own structure - don't add BOS/EOS
    tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
    return mx.array(tokenized_prompt)


def _cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
    return max(c.offset for c in cache)


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
    """Find the length of the common prefix between two token arrays."""
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
    if n == 0:
        return 0

    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
    return int(mx.sum(prefix_mask).item())


def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt: mx.array,
    cache: list[_BaseCache],
) -> None:
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=0,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        pass
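The cumprod trick in `_get_prefix_length` above is easiest to see with a worked example. The sketch below (not part of the diff) reproduces the same logic with plain Python lists so it can be checked by hand: the running product of the equality flags stays 1 until the first mismatch, so its sum is the common-prefix length.

```python
def prefix_length(prompt: list[int], cached: list[int]) -> int:
    # Plain-Python equivalent of: equal = (prompt[:n] == cached[:n]);
    # prefix_mask = cumprod(equal); return sum(prefix_mask)
    n = min(len(prompt), len(cached))
    running = 1
    total = 0
    for a, b in zip(prompt[:n], cached[:n]):
        running *= int(a == b)  # drops to 0 at the first mismatch and stays 0
        total += running
    return total


# [1, 2, 3, 4] vs [1, 2, 9, 4]: equality flags are [1, 1, 0, 1],
# the running product gives [1, 1, 0, 0], so the shared prefix has length 2.
assert prefix_length([1, 2, 3, 4], [1, 2, 9, 4]) == 2
```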
@@ -1,12 +1,12 @@
from typing import Any, Callable, Generator, cast, get_args
import time
from typing import Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
    BenchChatCompletionTaskParams,
    ChatCompletionMessage,
@@ -14,11 +14,13 @@ from exo.shared.types.api import (
    GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -30,19 +32,59 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())


def maybe_quantize_kv_cache(
    prompt_cache: list[KVCache | Any],
    quantized_kv_start: int,
    kv_group_size: int,
    kv_bits: int | None,
) -> None:
    if kv_bits is None:
        return
    for e, c in enumerate(prompt_cache):
        if (
            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
        ):
            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
    model: Model,
    tokenizer: TokenizerWrapper,
    sampler: Callable[[mx.array], mx.array],
    prompt_tokens: mx.array,
    cache: KVCacheType,
) -> float:
    """Prefill the KV cache with prompt tokens.

    This runs the model over the prompt tokens to populate the cache,
    then trims off the extra generated token.

    Returns:
        tokens_per_sec
    """
    num_tokens = len(prompt_tokens)
    if num_tokens == 0:
        return 0.0

    logger.debug(f"Prefilling {num_tokens} tokens...")
    start_time = time.perf_counter()

    def progress_callback(processed: int, total: int) -> None:
        elapsed = time.time() - start_time
        tok_per_sec = processed / elapsed if elapsed > 0 else 0
        logger.debug(
            f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
        )

    # Use max_tokens=1 because max_tokens=0 does not work.
    # We just throw away the generated token - we only care about filling the cache
    for _ in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt_tokens,
        max_tokens=1,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
        prompt_progress_callback=progress_callback,
    ):
        break  # Stop after first iteration - cache is now filled
    trim_prompt_cache(cache, 1)

    elapsed = time.perf_counter() - start_time
    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
    logger.debug(
        f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
        f"({tokens_per_sec:.1f} tok/s)"
    )
    return tokens_per_sec


def warmup_inference(
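The `prefill` helper above populates the cache by running `stream_generate` with `max_tokens=1` and then trimming the single throwaway token. A small sketch of that bookkeeping (not part of the diff), using a stand-in cache object that only tracks `offset` the way the stubs earlier in this compare do:

```python
class FakeLayerCache:
    """Stand-in for one KV cache layer; tracks only `offset` (tokens stored)."""

    def __init__(self) -> None:
        self.offset = 0

    def trim(self, n: int) -> int:
        trimmed = min(n, self.offset)
        self.offset -= trimmed
        return trimmed


# Suppose prefilling 7 prompt tokens with max_tokens=1 leaves 8 tokens per layer
# (prompt plus the one throwaway generated token, as the comment in the hunk
# describes); trimming 1 token per layer restores the prompt-only state.
cache = [FakeLayerCache(), FakeLayerCache()]
for layer in cache:
    layer.offset = 7 + 1
for layer in cache:
    layer.trim(1)
assert all(layer.offset == 7 for layer in cache)
```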
@@ -120,6 +162,7 @@ def mlx_generate(
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    prompt: str,
    kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -131,7 +174,16 @@
    if task.seed is not None:
        mx.random.seed(task.seed)

    caches = make_kv_cache(model=model)
    # Do not use the prefix cache if we are trying to do benchmarks.
    if is_bench:
        kv_prefix_cache = None

    # Use prefix cache if available, otherwise create fresh cache
    if kv_prefix_cache is None:
        caches = make_kv_cache(model=model)
        prompt_tokens = encode_prompt(tokenizer, prompt)
    else:
        caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
@@ -144,11 +196,19 @@
        top_p=task.top_p if task.top_p is not None else 1.0,
    )

    # Prefill cache with all tokens except the last one
    prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[-1:], caches)

    # stream_generate starts from the last token
    last_token = prompt_tokens[-1:]

    max_tokens = task.max_tokens or MAX_TOKENS
    generated_text_parts: list[str] = []
    generation_start_time = time.perf_counter()
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        prompt=last_token,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
@@ -158,12 +218,13 @@
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        generated_text_parts.append(out.text)
        logger.info(out.text)

        stats: GenerationStats | None = None
        if out.finish_reason is not None:
            stats = GenerationStats(
                prompt_tps=float(out.prompt_tps),
                prompt_tps=float(prefill_tps or out.prompt_tps),
                generation_tps=float(out.generation_tps),
                prompt_tokens=int(out.prompt_tokens),
                generation_tokens=int(out.generation_tokens),
@@ -185,6 +246,22 @@ def mlx_generate(
        )

        if out.finish_reason is not None:
            # Log generation stats
            generation_elapsed = time.perf_counter() - generation_start_time
            generated_tokens = len(generated_text_parts)
            generation_tps = (
                generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
            )
            logger.debug(
                f"Generation complete: prefill {prompt_tokens} tokens @ "
                f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
                f"{generation_tps:.1f} tok/s"
            )
            # Save cache for future prefix matching (clear first to keep only the last one)
            if kv_prefix_cache is not None:
                kv_prefix_cache.clear()
                full_prompt = prompt + "".join(generated_text_parts)
                kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
            break

    # TODO: Do we want an mx_barrier?
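Read together, the `mlx_generate` hunks above amount to the flow sketched below. This is a condensed, hypothetical wrapper for orientation only, reusing names and import paths that appear in the diff (`make_kv_cache`, `encode_prompt`, `prefill`, `KVPrefixCache`, `MAX_TOKENS`); note that the hunk's code passes `prompt_tokens[-1:]` to `prefill` while its comment says all tokens except the last one, and the sketch follows the comment.

```python
from mlx_lm.generate import stream_generate

from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.generate import prefill
from exo.worker.engines.mlx.utils_mlx import make_kv_cache


def generate_with_prefix_cache(model, tokenizer, sampler, prompt: str,
                               kv_prefix_cache: KVPrefixCache | None) -> str:
    # 1. Reuse the longest matching cached prefix, or start from a fresh cache.
    if kv_prefix_cache is None:
        caches = make_kv_cache(model=model)
        prompt_tokens = encode_prompt(tokenizer, prompt)
    else:
        caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)

    # 2. Prefill the remaining tokens, then generate starting from the last token.
    prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
    last_token = prompt_tokens[-1:]
    parts = [out.text for out in stream_generate(
        model=model, tokenizer=tokenizer, prompt=last_token,
        max_tokens=MAX_TOKENS, sampler=sampler, prompt_cache=caches,
    )]
    text = "".join(parts)

    # 3. Save prompt + generation so the next request can reuse the prefix.
    if kv_prefix_cache is not None:
        kv_prefix_cache.clear()
        kv_prefix_cache.add_kv_cache(tokenizer, prompt + text, caches)
    return text
```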
@@ -70,6 +70,7 @@ from exo.worker.engines.image import (
    warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
@@ -103,6 +104,7 @@ def main(
    model: Model | DistributedImageModel | None = None
    tokenizer = None
    group = None
    kv_prefix_cache: KVPrefixCache | None = None

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")
@@ -171,6 +173,9 @@
                        f"Unknown model task(s): {shard_metadata.model_card.tasks}"
                    )

                if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                    kv_prefix_cache = KVPrefixCache()

                current_status = RunnerLoaded()
                logger.info("runner loaded")
            case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -238,6 +243,7 @@
                    tokenizer=tokenizer,
                    task=task_params,
                    prompt=prompt,
                    kv_prefix_cache=kv_prefix_cache,
                )

                # GPT-OSS specific parsing to match other model formats.