Compare commits

..

2 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
8f6f2f3065 Add fixes 2026-01-20 17:13:02 +00:00
Evan
e6af53c2ae foo 2026-01-20 17:12:31 +00:00
281 changed files with 1289 additions and 12817 deletions

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import os
if "TOKENIZERS_PARALLELISM" not in os.environ: ...

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,47 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import Protocol
from mflux.models.common.config.config import Config
class BeforeLoopCallback(Protocol):
def call_before_loop(
self,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
class InLoopCallback(Protocol):
def call_in_loop(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...
class AfterLoopCallback(Protocol):
def call_after_loop(
self, seed: int, prompt: str, latents: mx.array, config: Config
) -> None: ...
class InterruptCallback(Protocol):
def call_interrupt(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...

View File

@@ -1,24 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.callbacks.callback import (
AfterLoopCallback,
BeforeLoopCallback,
InLoopCallback,
InterruptCallback,
)
from mflux.callbacks.generation_context import GenerationContext
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class CallbackRegistry:
def __init__(self) -> None: ...
def register(self, callback) -> None: ...
def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
def in_loop_callbacks(self) -> list[InLoopCallback]: ...
def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
def interrupt_callbacks(self) -> list[InterruptCallback]: ...

View File

@@ -1,29 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import TYPE_CHECKING
from mflux.callbacks.callback_registry import CallbackRegistry
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class GenerationContext:
def __init__(
self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
) -> None: ...
def before_loop(
self,
latents: mx.array,
*,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...
def after_loop(self, latents: mx.array) -> None: ...
def interruption(
self, t: int, latents: mx.array, time_steps: tqdm = ...
) -> None: ...

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,22 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import os
BATTERY_PERCENTAGE_STOP_LIMIT = ...
CONTROLNET_STRENGTH = ...
DEFAULT_DEV_FILL_GUIDANCE = ...
DEFAULT_DEPTH_GUIDANCE = ...
DIMENSION_STEP_PIXELS = ...
GUIDANCE_SCALE = ...
GUIDANCE_SCALE_KONTEXT = ...
IMAGE_STRENGTH = ...
MODEL_CHOICES = ...
MODEL_INFERENCE_STEPS = ...
QUANTIZE_CHOICES = ...
if os.environ.get("MFLUX_CACHE_DIR"):
MFLUX_CACHE_DIR = ...
else:
MFLUX_CACHE_DIR = ...
MFLUX_LORA_CACHE_DIR = ...

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,8 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
__all__ = ["Config", "ModelConfig"]

View File

@@ -1,66 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import Any
from tqdm import tqdm
from mflux.models.common.config.model_config import ModelConfig
logger = ...
class Config:
def __init__(
self,
model_config: ModelConfig,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
depth_image_path: Path | str | None = ...,
redux_image_paths: list[Path | str] | None = ...,
redux_image_strengths: list[float] | None = ...,
masked_image_path: Path | str | None = ...,
controlnet_strength: float | None = ...,
scheduler: str = ...,
) -> None: ...
@property
def height(self) -> int: ...
@property
def width(self) -> int: ...
@width.setter
def width(self, value): # -> None:
...
@property
def image_seq_len(self) -> int: ...
@property
def guidance(self) -> float: ...
@property
def num_inference_steps(self) -> int: ...
@property
def precision(self) -> mx.Dtype: ...
@property
def num_train_steps(self) -> int: ...
@property
def image_path(self) -> Path | None: ...
@property
def image_strength(self) -> float | None: ...
@property
def depth_image_path(self) -> Path | None: ...
@property
def redux_image_paths(self) -> list[Path] | None: ...
@property
def redux_image_strengths(self) -> list[float] | None: ...
@property
def masked_image_path(self) -> Path | None: ...
@property
def init_time_step(self) -> int: ...
@property
def time_steps(self) -> tqdm: ...
@property
def controlnet_strength(self) -> float | None: ...
@property
def scheduler(self) -> Any: ...

View File

@@ -1,86 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from functools import lru_cache
from typing import Literal
class ModelConfig:
precision: mx.Dtype = ...
def __init__(
self,
priority: int,
aliases: list[str],
model_name: str,
base_model: str | None,
controlnet_model: str | None,
custom_transformer_model: str | None,
num_train_steps: int | None,
max_sequence_length: int | None,
supports_guidance: bool | None,
requires_sigma_shift: bool | None,
transformer_overrides: dict | None = ...,
) -> None: ...
@staticmethod
@lru_cache
def dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_kontext() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_redux() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_depth() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_upscaler() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill_catvton() -> ModelConfig: ...
@staticmethod
@lru_cache
def krea_dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_4b() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_9b() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image_edit() -> ModelConfig: ...
@staticmethod
@lru_cache
def fibo() -> ModelConfig: ...
@staticmethod
@lru_cache
def z_image_turbo() -> ModelConfig: ...
@staticmethod
@lru_cache
def seedvr2_3b() -> ModelConfig: ...
def x_embedder_input_dim(self) -> int: ...
def is_canny(self) -> bool: ...
@staticmethod
def from_name(
model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
) -> ModelConfig: ...
AVAILABLE_MODELS = ...

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,49 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.z_image.latent_creator.z_image_latent_creator import (
ZImageLatentCreator,
)
if TYPE_CHECKING:
LatentCreatorType: TypeAlias = type[
FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
]
class Img2Img:
def __init__(
self,
vae: nn.Module,
latent_creator: LatentCreatorType,
sigmas: mx.array,
init_time_step: int,
image_path: str | Path | None,
tiling_config: TilingConfig | None = ...,
) -> None: ...
class LatentCreator:
@staticmethod
def create_for_txt2img_or_img2img(
seed: int, height: int, width: int, img2img: Img2Img
) -> mx.array: ...
@staticmethod
def encode_image(
vae: nn.Module,
image_path: str | Path,
height: int,
width: int,
tiling_config: TilingConfig | None = ...,
) -> mx.array: ...
@staticmethod
def add_noise_by_interpolation(
clean: mx.array, noise: mx.array, sigma: float
) -> mx.array: ...

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,13 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
class FusedLoRALinear(nn.Module):
def __init__(
self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -1,22 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
class LoRALinear(nn.Module):
@staticmethod
def from_linear(
linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
): # -> LoRALinear:
...
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = ...,
scale: float = ...,
bias: bool = ...,
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -1,26 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
from collections.abc import Callable
from dataclasses import dataclass
from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
@dataclass
class PatternMatch:
source_pattern: str
target_path: str
matrix_name: str
transpose: bool
transform: Callable[[mx.array], mx.array] | None = ...
class LoRALoader:
@staticmethod
def load_and_apply_lora(
lora_mapping: list[LoRATarget],
transformer: nn.Module,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> tuple[list[str], list[float]]: ...

View File

@@ -1,21 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Protocol
@dataclass
class LoRATarget:
model_path: str
possible_up_patterns: List[str]
possible_down_patterns: List[str]
possible_alpha_patterns: List[str] = ...
up_transform: Callable[[mx.array], mx.array] | None = ...
down_transform: Callable[[mx.array], mx.array] | None = ...
class LoRAMapping(Protocol):
@staticmethod
def get_mapping() -> List[LoRATarget]: ...

View File

@@ -1,9 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
class LoRASaver:
@staticmethod
def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...

View File

@@ -1,35 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class LoraTransforms:
@staticmethod
def split_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_down(tensor: mx.array) -> mx.array: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.resolution.config_resolution import ConfigResolution
from mflux.models.common.resolution.lora_resolution import LoraResolution
from mflux.models.common.resolution.path_resolution import PathResolution
from mflux.models.common.resolution.quantization_resolution import (
QuantizationResolution,
)
__all__ = [
"ConfigResolution",
"LoraResolution",
"PathResolution",
"QuantizationResolution",
]

View File

@@ -1,39 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from enum import Enum
from typing import NamedTuple
class QuantizationAction(Enum):
NONE = ...
STORED = ...
REQUESTED = ...
class PathAction(Enum):
LOCAL = ...
HUGGINGFACE_CACHED = ...
HUGGINGFACE = ...
ERROR = ...
class LoraAction(Enum):
LOCAL = ...
REGISTRY = ...
HUGGINGFACE_COLLECTION_CACHED = ...
HUGGINGFACE_COLLECTION = ...
HUGGINGFACE_REPO_CACHED = ...
HUGGINGFACE_REPO = ...
ERROR = ...
class ConfigAction(Enum):
EXACT_MATCH = ...
EXPLICIT_BASE = ...
INFER_SUBSTRING = ...
ERROR = ...
class Rule(NamedTuple):
priority: int
name: str
check: str
action: QuantizationAction | PathAction | LoraAction | ConfigAction
...

View File

@@ -1,14 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.config.model_config import ModelConfig
if TYPE_CHECKING: ...
logger = ...
class ConfigResolution:
RULES = ...
@staticmethod
def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...

View File

@@ -1,21 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class LoraResolution:
RULES = ...
_registry: dict[str, Path] = ...
@staticmethod
def resolve(path: str) -> str: ...
@staticmethod
def resolve_paths(paths: list[str] | None) -> list[str]: ...
@staticmethod
def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
@staticmethod
def get_registry() -> dict[str, Path]: ...
@staticmethod
def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class PathResolution:
RULES = ...
@staticmethod
def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
logger = ...
class QuantizationResolution:
RULES = ...
@staticmethod
def resolve(
stored: int | None, requested: int | None
) -> tuple[int | None, str | None]: ...

View File

@@ -1,26 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
from .linear_scheduler import LinearScheduler
from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
__all__ = [
"LinearScheduler",
"FlowMatchEulerDiscreteScheduler",
"SeedVR2EulerScheduler",
]
class SchedulerModuleNotFound(ValueError): ...
class SchedulerClassNotFound(ValueError): ...
class InvalidSchedulerType(TypeError): ...
SCHEDULER_REGISTRY = ...
def register_contrib(scheduler_object, scheduler_name=...): # -> None:
...
def try_import_external_scheduler(
scheduler_object_path: str,
): # -> type[BaseScheduler]:
...

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from abc import ABC, abstractmethod
class BaseScheduler(ABC):
@property
@abstractmethod
def sigmas(self) -> mx.array: ...
@abstractmethod
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -1,26 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class FlowMatchEulerDiscreteScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def set_image_seq_len(self, image_seq_len: int) -> None: ...
@staticmethod
def get_timesteps_and_sigmas(
image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
) -> tuple[mx.array, mx.array]: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -1,20 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class LinearScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -1,20 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class SeedVR2EulerScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def timesteps(self) -> mx.array: ...
@property
def sigmas(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -1,24 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.tokenizer.tokenizer import (
BaseTokenizer,
LanguageTokenizer,
Tokenizer,
VisionLanguageTokenizer,
)
from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
__all__ = [
"Tokenizer",
"BaseTokenizer",
"LanguageTokenizer",
"VisionLanguageTokenizer",
"TokenizerLoader",
"TokenizerOutput",
]

View File

@@ -1,74 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
from PIL import Image
from transformers import PreTrainedTokenizer
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
@runtime_checkable
class Tokenizer(Protocol):
tokenizer: PreTrainedTokenizer
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class BaseTokenizer(ABC):
def __init__(
self, tokenizer: PreTrainedTokenizer, max_length: int = ...
) -> None: ...
@abstractmethod
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class LanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = ...,
padding: str = ...,
return_attention_mask: bool = ...,
template: str | None = ...,
use_chat_template: bool = ...,
chat_template_kwargs: dict | None = ...,
add_special_tokens: bool = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class VisionLanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
processor,
max_length: int = ...,
template: str | None = ...,
image_token: str = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...

View File

@@ -1,22 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING: ...
class TokenizerLoader:
@staticmethod
def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
@staticmethod
def load_all(
definitions: list[TokenizerDefinition],
model_path: str,
max_length_overrides: dict[str, int] | None = ...,
) -> dict[str, BaseTokenizer]: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
"""
This type stub file was generated by pyright.
"""
@dataclass
class TokenizerOutput:
input_ids: mx.array
attention_mask: mx.array
pixel_values: mx.array | None = ...
image_grid_thw: mx.array | None = ...

View File

@@ -1,8 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.common.vae.vae_tiler import VAETiler
__all__ = ["TilingConfig", "VAETiler"]

View File

@@ -1,13 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TilingConfig:
vae_decode_tiles_per_dim: int | None = ...
vae_decode_overlap: int = ...
vae_encode_tiled: bool = ...
vae_encode_tile_size: int = ...
vae_encode_tile_overlap: int = ...

View File

@@ -1,27 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Callable
class VAETiler:
@staticmethod
def encode_image_tiled(
*,
image: mx.array,
encode_fn: Callable[[mx.array], mx.array],
latent_channels: int,
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...
@staticmethod
def decode_image_tiled(
*,
latent: mx.array,
decode_fn: Callable[[mx.array], mx.array],
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
class VAEUtil:
@staticmethod
def encode(
vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...
@staticmethod
def decode(
vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...

View File

@@ -1,18 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
from mflux.models.common.weights.loading.weight_applier import WeightApplier
from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
from mflux.models.common.weights.loading.weight_loader import WeightLoader
from mflux.models.common.weights.saving.model_saver import ModelSaver
__all__ = [
"ComponentDefinition",
"LoadedWeights",
"MetaData",
"ModelSaver",
"WeightApplier",
"WeightLoader",
]

View File

@@ -1,18 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass
class MetaData:
quantization_level: int | None = ...
mflux_version: str | None = ...
@dataclass
class LoadedWeights:
components: dict[str, dict]
meta_data: MetaData
def __getattr__(self, name: str) -> dict | None: ...
def num_transformer_blocks(self, component_name: str = ...) -> int: ...
def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...

View File

@@ -1,30 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
class WeightApplier:
@staticmethod
def apply_and_quantize_single(
weights: LoadedWeights,
model: nn.Module,
component: ComponentDefinition,
quantize_arg: int | None,
quantization_predicate=...,
) -> int | None: ...
@staticmethod
def apply_and_quantize(
weights: LoadedWeights,
models: dict[str, nn.Module],
quantize_arg: int | None,
weight_definition: WeightDefinitionType,
) -> int | None: ...

View File

@@ -1,73 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, TYPE_CHECKING, TypeAlias
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
DepthProWeightDefinition,
)
from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
FIBOVLMWeightDefinition,
)
from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
SeedVR2WeightDefinition,
)
from mflux.models.z_image.weights.z_image_weight_definition import (
ZImageWeightDefinition,
)
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING:
WeightDefinitionType: TypeAlias = type[
FluxWeightDefinition
| FIBOWeightDefinition
| FIBOVLMWeightDefinition
| QwenWeightDefinition
| ZImageWeightDefinition
| SeedVR2WeightDefinition
| DepthProWeightDefinition
]
@dataclass
class ComponentDefinition:
name: str
hf_subdir: str
mapping_getter: Callable[[], List[WeightTarget]] | None = ...
model_attr: str | None = ...
num_blocks: int | None = ...
num_layers: int | None = ...
loading_mode: str = ...
precision: mx.Dtype | None = ...
skip_quantization: bool = ...
bulk_transform: Callable[[mx.array], mx.array] | None = ...
weight_subkey: str | None = ...
download_url: str | None = ...
weight_prefix_filters: List[str] | None = ...
weight_files: List[str] | None = ...
@dataclass
class TokenizerDefinition:
name: str
hf_subdir: str
tokenizer_class: str = ...
fallback_subdirs: List[str] | None = ...
download_patterns: List[str] | None = ...
encoder_class: type[BaseTokenizer] | None = ...
max_length: int = ...
padding: str = ...
template: str | None = ...
use_chat_template: bool = ...
chat_template_kwargs: dict | None = ...
add_special_tokens: bool = ...
processor_class: type | None = ...
image_token: str = ...
chat_template: str | None = ...

View File

@@ -1,23 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
logger = ...
class WeightLoader:
@staticmethod
def load_single(
component: ComponentDefinition, repo_id: str, file_pattern: str = ...
) -> LoadedWeights: ...
@staticmethod
def load(
weight_definition: WeightDefinitionType, model_path: str | None = ...
) -> LoadedWeights: ...

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Dict, List, Optional
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
class WeightMapper:
@staticmethod
def apply_mapping(
hf_weights: Dict[str, mx.array],
mapping: List[WeightTarget],
num_blocks: Optional[int] = ...,
num_layers: Optional[int] = ...,
) -> Dict: ...

View File

@@ -1,23 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol
"""
This type stub file was generated by pyright.
"""
@dataclass
class WeightTarget:
to_pattern: str
from_pattern: List[str]
transform: Optional[Callable[[mx.array], mx.array]] = ...
required: bool = ...
max_blocks: Optional[int] = ...
class WeightMapping(Protocol):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class WeightTransforms:
@staticmethod
def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...

View File

@@ -1,14 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import Any, TYPE_CHECKING
from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
if TYPE_CHECKING: ...
class ModelSaver:
@staticmethod
def save_model(
model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
) -> None: ...

View File

@@ -1,9 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
class DepthProInitializer:
@staticmethod
def init(model: DepthProModel, quantize: int | None = ...) -> None: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FeatureFusionBlock2d(nn.Module):
def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MultiresConvDecoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self,
x0_latent: mx.array,
x1_latent: mx.array,
x0_features: mx.array,
x1_features: mx.array,
x_global_features: mx.array,
) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, num_features: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,20 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
@dataclass
class DepthResult:
depth_image: Image.Image
depth_array: mx.array
min_depth: float
max_depth: float
...
class DepthPro:
def __init__(self, quantize: int | None = ...) -> None: ...
def create_depth_map(self, image_path: str | Path) -> DepthResult: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProModel(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -1,15 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProUtil:
@staticmethod
def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
@staticmethod
def interpolate(x: mx.array, size=..., scale_factor=...): # -> array:
...
@staticmethod
def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class Attention(nn.Module):
def __init__(
self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DinoVisionTransformer(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class LayerScale(nn.Module):
def __init__(self, dims: int, init_values: float = ...) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class PatchEmbed(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class TransformerBlock(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class UpSampleBlock(nn.Module):
def __init__(
self,
dim_in: int = ...,
dim_int: int = ...,
dim_out: int = ...,
upsample_layers: int = ...,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FOVHead(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -1,23 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class DepthProWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -1,13 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class DepthProWeightMapping(WeightMapping):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -1,13 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class FiboLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -1,23 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class FIBOWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOWeightMapping(WeightMapping):
@staticmethod
def get_transformer_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_text_encoder_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_vae_mapping() -> List[WeightTarget]: ...

View File

@@ -1,8 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor
class Qwen2VLImageProcessor(QwenImageProcessor):
def __init__(self) -> None: ...

View File

@@ -1,28 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import Optional, Union
from PIL import Image
class Qwen2VLProcessor:
def __init__(self, tokenizer) -> None: ...
def apply_chat_template(
self,
messages,
tokenize: bool = ...,
add_generation_prompt: bool = ...,
return_tensors: Optional[str] = ...,
return_dict: bool = ...,
**kwargs,
): # -> dict[Any, Any]:
...
def __call__(
self,
text: Optional[Union[str, list[str]]] = ...,
images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
padding: bool = ...,
return_tensors: Optional[str] = ...,
**kwargs,
): # -> dict[Any, Any]:
...

View File

@@ -1,24 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
QWEN2VL_CHAT_TEMPLATE = ...
class FIBOVLMWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -1,15 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOVLMWeightMapping(WeightMapping):
@staticmethod
def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
@staticmethod
def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,53 +0,0 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config import ModelConfig
class FluxInitializer:
@staticmethod
def init(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
custom_transformer=...,
) -> None: ...
@staticmethod
def init_depth(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_redux(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_controlnet(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_concept(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,19 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
"""
This type stub file was generated by pyright.
"""
class FluxLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(
latents: mx.array, height: int, width: int, num_channels_latents: int = ...
) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEmbeddings(nn.Module):
def __init__(self, dims: int) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -1,14 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class CLIPEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEncoderLayer(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPMLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def quick_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -1,18 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPSdpaAttention(nn.Module):
head_dimension = ...
batch_size = ...
num_heads = ...
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...
@staticmethod
def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPTextModel(nn.Module):
def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
@staticmethod
def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class EncoderCLIP(nn.Module):
def __init__(self, num_encoder_layers: int) -> None: ...
def __call__(
self, tokens: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -1,25 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.common.tokenizer import Tokenizer
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
"""
This type stub file was generated by pyright.
"""
class PromptEncoder:
@staticmethod
def encode_prompt(
prompt: str,
prompt_cache: dict[str, tuple[mx.array, mx.array]],
t5_tokenizer: Tokenizer,
clip_tokenizer: Tokenizer,
t5_text_encoder: T5Encoder,
clip_text_encoder: CLIPEncoder,
) -> tuple[mx.array, mx.array]: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Attention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Block(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5DenseReluDense(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def new_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -1,14 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class T5Encoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array): ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5FeedForward(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5LayerNorm(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -1,16 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5SelfAttention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def shape(states): # -> array:
...
@staticmethod
def un_shape(states): # -> array:
...

View File

@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...
def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZero(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZeroSingle(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -1,41 +0,0 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AttentionUtils:
@staticmethod
def process_qkv(
hidden_states: mx.array,
to_q: nn.Linear,
to_k: nn.Linear,
to_v: nn.Linear,
norm_q: nn.RMSNorm,
norm_k: nn.RMSNorm,
num_heads: int,
head_dim: int,
) -> tuple[mx.array, mx.array, mx.array]: ...
@staticmethod
def compute_attention(
query: mx.array,
key: mx.array,
value: mx.array,
batch_size: int,
num_heads: int,
head_dim: int,
mask: mx.array | None = ...,
) -> mx.array: ...
@staticmethod
def convert_key_padding_mask_to_additive_mask(
mask: mx.array | None, joint_seq_len: int, txt_seq_len: int
) -> mx.array | None: ...
@staticmethod
def apply_rope(
xq: mx.array, xk: mx.array, freqs_cis: mx.array
) -> tuple[mx.array, mx.array]: ...
@staticmethod
def apply_rope_bshd(
xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array
) -> tuple[mx.array, mx.array]: ...

Some files were not shown because too many files have changed in this diff Show More