mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
Consolidate cleanup
This commit is contained in:
@@ -2,14 +2,24 @@
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Protocol, Literal, Self
|
||||
|
||||
import mlx.nn as nn
|
||||
from mlx.core import array
|
||||
import mlx.core as mx
|
||||
|
||||
class Cache(Protocol):
    """Structural (duck-typed) interface for per-layer KV caches.

    Any cache object that exposes ``keys``/``values`` arrays, an
    ``update_and_fetch`` method, and a readable/writable ``state``
    property satisfies this protocol.
    """

    # Cached key/value tensors; shapes depend on the concrete cache class.
    keys: mx.array
    values: mx.array

    # NOTE(review): declared ``-> None`` here, but concrete mlx caches
    # typically return the updated (keys, values) — confirm against mlx-lm.
    def update_and_fetch(self, keys: mx.array, values: mx.array) -> None: ...

    @property
    def state(self) -> tuple[mx.array, mx.array]: ...

    @state.setter
    def state(self, v) -> None: ...
|
||||
|
||||
def make_prompt_cache(
    model: nn.Module, max_kv_size: Optional[int] = ...
) -> List[Cache | Any]:
    """
    Construct the model's cache for use in generation.

    Args:
        model: Model whose layers determine the per-layer caches.
        max_kv_size: Optional bound on the KV cache size.
            (presumably caps stored tokens per layer — TODO confirm)

    Returns:
        One cache object per model layer.
    """
|
||||
|
||||
def save_prompt_cache(
    file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
) -> None:
    """
    Save a pre-computed prompt cache to a file.

    Args:
        file_name: Destination file path.
        cache: Per-layer caches to serialize.
        metadata: Optional string metadata stored alongside the cache.
    """
|
||||
@@ -50,12 +60,12 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
|
||||
the metadata if requested.
|
||||
"""
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Cache]) -> bool:
    """
    Check if model's cache can be trimmed.

    Args:
        cache: Per-layer caches to inspect.

    Returns:
        True when every cache in the list supports trimming.
    """
|
||||
|
||||
def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
    """
    Trim the model's cache by the given number of tokens.

    Args:
        cache: Per-layer caches to trim in place.
        num_tokens: Number of tokens to remove.
    """
|
||||
|
||||
@@ -72,27 +82,22 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
|
||||
def create_attention_mask(
    N: int, offset: int, return_array: bool, window_size: Optional[int]
) -> array | Literal["causal"] | None:
    """Build an attention mask for ``N`` query tokens at ``offset``.

    Returns a materialized mask array, the string ``"causal"`` (letting the
    backend apply a causal mask), or ``None`` when no mask is needed.
    ``window_size`` presumably enables sliding-window masking — TODO confirm.
    """
    ...
|
||||
|
||||
class _BaseCache(Cache):
    """Common base for concrete KV cache implementations.

    Provides the ``Cache`` protocol surface plus serialization hooks
    (``meta_state`` / ``from_state``) and a trimming capability flag.
    """

    # Cached key/value tensors; shapes depend on the concrete subclass.
    keys: mx.array
    values: mx.array

    @property
    def state(self) -> tuple[mx.array, mx.array]: ...

    @state.setter
    def state(self, v) -> None: ...

    # Base caches carry no extra metadata, hence the empty-string literal.
    @property
    def meta_state(self) -> Literal[""]: ...

    @meta_state.setter
    def meta_state(self, v) -> None: ...

    # Base implementation is not trimmable; subclasses override.
    def is_trimmable(self) -> Literal[False]: ...

    # Alternate constructor used when restoring a serialized cache.
    @classmethod
    def from_state(cls, state, meta_state) -> Self: ...
|
||||
|
||||
class ConcatenateKVCache(_BaseCache):
|
||||
"""ConcatenateKVCache the simplest KV cache implementation.
|
||||
|
||||
Reference in New Issue
Block a user