Consolidate cleanup

This commit is contained in:
rltakashige
2025-11-21 14:54:02 +00:00
committed by GitHub
parent 28a91787e8
commit b45cbdeecd
72 changed files with 634 additions and 4854 deletions

View File

@@ -2,14 +2,24 @@
This type stub file was generated by pyright.
"""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Protocol, Literal, Self
import mlx.nn as nn
from mlx.core import array
import mlx.core as mx
class Cache(Protocol):
    """Structural interface satisfied by every KV-cache implementation.

    Any object that exposes ``keys``/``values`` arrays, an
    ``update_and_fetch`` method, and a read/write ``state`` property
    conforms — being a ``Protocol``, no inheritance is required.
    """

    # Accumulated attention keys/values for the cached tokens.
    keys: mx.array
    values: mx.array

    def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...

    @property
    def state(self) -> tuple[mx.array, mx.array]: ...

    @state.setter
    def state(self, v) -> None: ...
def make_prompt_cache(
model: nn.Module, max_kv_size: Optional[int] = ...
) -> List[KVCache | Any]:
) -> List[Cache | Any]:
"""
Construct the model's cache for use in generation.
@@ -24,7 +34,7 @@ def make_prompt_cache(
"""
def save_prompt_cache(
file_name: str, cache: List[Any], metadata: Dict[str, str] = ...
file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
) -> None:
"""
Save a pre-computed prompt cache to a file.
@@ -50,12 +60,12 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
the metadata if requested.
"""
def can_trim_prompt_cache(cache: List[Cache]) -> bool:
    """
    Check if model's cache can be trimmed.
    """
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
"""
Trim the model's cache by the given number of tokens.
@@ -72,27 +82,22 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
def create_attention_mask(
    N: int, offset: int, return_array: bool, window_size: Optional[int]
) -> array | Literal["causal"] | None:
    """Build an attention mask for ``N`` query tokens starting at ``offset``.

    Per the return annotation, yields an explicit mask array, the sentinel
    string ``"causal"``, or ``None``; the exact conditions are decided by the
    implementation this stub describes (NOTE(review): semantics inferred from
    the signature — confirm against the real module).
    """
class _BaseCache(Cache):
    """Common base for the concrete KV-cache classes in this stub.

    Declares the attributes and accessors shared by every cache subtype;
    concrete subclasses provide the real behavior.
    """

    # Cached attention keys/values.
    keys: mx.array
    values: mx.array

    @property
    def state(self) -> tuple[mx.array, mx.array]: ...

    @state.setter
    def state(self, v) -> None: ...

    # Serializable metadata; empty string in the base class per the
    # Literal[""] annotation.
    @property
    def meta_state(self) -> Literal[""]: ...

    @meta_state.setter
    def meta_state(self, v) -> None: ...

    # The base class reports itself as non-trimmable (Literal[False]).
    def is_trimmable(self) -> Literal[False]: ...

    @classmethod
    def from_state(cls, state, meta_state) -> Self: ...
class ConcatenateKVCache(_BaseCache):
"""ConcatenateKVCache the simplest KV cache implementation.