# Mirror of https://github.com/exo-explore/exo.git
# (synced 2025-12-23 22:27:50 -05:00; 48 lines, 1.1 KiB, Python)
"""
This type stub file was generated by pyright.
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
|
|
@dataclass
class BaseModelArgs:
    """Dataclass declaring a ``from_dict`` alternate constructor (stub).

    No fields are declared on this class in the stub. The name suggests it
    is a base for model-argument dataclasses -- NOTE(review): confirm.
    """

    @classmethod
    def from_dict(cls, params):
        """Construct an instance of ``cls`` from ``params``.

        Pyright inferred the return type as ``Self`` (an instance of the
        class the method is invoked on); it is kept as documentation rather
        than an annotation. NOTE(review): ``params`` is untyped here --
        presumably a mapping of configuration keys to values; confirm
        against the implementation.
        """
        ...
|
|
|
|
def create_causal_mask(
    N: int,
    offset: int = ...,
    window_size: Optional[int] = ...,
    right_padding: Optional[mx.array] = ...,
    left_padding: Optional[mx.array] = ...,
) -> mx.array:
    """Return an ``mx.array`` mask for a sequence of length ``N`` (stub).

    Judging by the name this is a causal attention mask -- NOTE(review):
    confirm against the implementation. A default written as ``...`` means
    the parameter has a default whose value the stub does not show.

    Args:
        N: Sequence length (an ``int``).
        offset: ``int`` with an elided default; presumably a position
            offset into an existing sequence -- confirm.
        window_size: Optional ``int``; presumably limits how far back a
            position may attend -- confirm.
        right_padding: Optional ``mx.array``; presumably per-sequence
            right-padding amounts -- confirm.
        left_padding: Optional ``mx.array``; presumably per-sequence
            left-padding amounts -- confirm.

    Returns:
        An ``mx.array`` mask.
    """
    # The return type matches pyright's inference for the implementation
    # (it was previously recorded only as a comment). As a real annotation,
    # type checkers no longer treat the result as unknown.
    ...
|
|
def create_attention_mask(
    h,
    cache=...,
    window_size: Optional[int] = ...,
    return_array: bool = ...,
):
    """Stub for the attention-mask factory.

    Pyright inferred the return type ``array | Literal['causal'] | None``:
    an ``mx.array`` mask, the literal string ``"causal"``, or ``None``.
    That type is documented here rather than annotated.

    NOTE(review): ``h`` and ``cache`` are untyped in the stub. Presumably
    ``h`` is a tensor of hidden states and ``cache`` holds KV-cache objects;
    confirm against callers. ``= ...`` marks a default whose value is elided.
    """
    ...
|
|
def create_ssm_mask(h, cache=...):
    """Stub for the mask factory presumably used by SSM (state-space) layers.

    Pyright inferred ``None`` as the return type. NOTE(review): ``h`` and
    ``cache`` are untyped, so that inference may omit paths whose types
    pyright could not resolve; confirm before promoting it to an annotation.
    """
    ...
|
|
def quantized_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array, mx.array],
    q_values: tuple[mx.array, mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    group_size: int = ...,
    bits: int = ...,
) -> mx.array:
    """Scaled dot-product attention against quantized keys/values (stub).

    Args:
        queries: Query tensor (``mx.array``).
        q_keys: Triple of ``mx.array``; presumably the packed quantized key
            data plus its quantization scales and biases -- confirm.
        q_values: Triple of ``mx.array`` with the same layout as ``q_keys``,
            holding the values.
        scale: ``float`` factor; presumably applied to the query/key scores
            -- confirm.
        mask: Optional ``mx.array`` mask (required positional/keyword).
        group_size: ``int`` with an elided default; presumably the
            quantization group size -- confirm.
        bits: ``int`` with an elided default; presumably the quantization
            bit width -- confirm.

    Returns:
        An ``mx.array``.
    """
    ...
|
|
def scaled_dot_product_attention(
    queries,
    keys,
    values,
    cache,
    scale: float,
    mask: Optional[mx.array],
    sinks: Optional[mx.array] = ...,
) -> mx.array:
    """Scaled dot-product attention entry point that also receives ``cache``.

    ``queries``, ``keys``, ``values`` and ``cache`` are untyped in this stub.
    NOTE(review): ``keys``/``values`` may be left untyped because they are
    not always plain ``mx.array`` (e.g. quantized tuples, depending on
    ``cache``) -- confirm against the implementation.

    Args:
        scale: ``float`` factor; presumably applied to the query/key scores.
        mask: Optional ``mx.array`` mask.
        sinks: Optional ``mx.array`` with an elided default; presumably
            attention-sink values -- confirm.

    Returns:
        An ``mx.array``.
    """
    ...
|