mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-22 04:51:19 -05:00
Compare commits
14 Commits
foo
...
tool-calli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5266f8d7ae | ||
|
|
75036ee9f6 | ||
|
|
023108a19d | ||
|
|
c9818c30b4 | ||
|
|
8f6726d6be | ||
|
|
ede779219c | ||
|
|
a7e205e489 | ||
|
|
a354aaa3e5 | ||
|
|
307f454b96 | ||
|
|
a31b6ee045 | ||
|
|
6a9251b920 | ||
|
|
758464703d | ||
|
|
9e2179c848 | ||
|
|
22b5d836ef |
7
.mlx_typings/mflux/__init__.pyi
Normal file
7
.mlx_typings/mflux/__init__.pyi
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
if "TOKENIZERS_PARALLELISM" not in os.environ: ...
|
||||||
3
.mlx_typings/mflux/callbacks/__init__.pyi
Normal file
3
.mlx_typings/mflux/callbacks/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
47
.mlx_typings/mflux/callbacks/callback.pyi
Normal file
47
.mlx_typings/mflux/callbacks/callback.pyi
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import PIL.Image
|
||||||
|
import tqdm
|
||||||
|
from typing import Protocol
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
|
||||||
|
class BeforeLoopCallback(Protocol):
|
||||||
|
def call_before_loop(
|
||||||
|
self,
|
||||||
|
seed: int,
|
||||||
|
prompt: str,
|
||||||
|
latents: mx.array,
|
||||||
|
config: Config,
|
||||||
|
canny_image: PIL.Image.Image | None = ...,
|
||||||
|
depth_image: PIL.Image.Image | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class InLoopCallback(Protocol):
|
||||||
|
def call_in_loop(
|
||||||
|
self,
|
||||||
|
t: int,
|
||||||
|
seed: int,
|
||||||
|
prompt: str,
|
||||||
|
latents: mx.array,
|
||||||
|
config: Config,
|
||||||
|
time_steps: tqdm,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class AfterLoopCallback(Protocol):
|
||||||
|
def call_after_loop(
|
||||||
|
self, seed: int, prompt: str, latents: mx.array, config: Config
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class InterruptCallback(Protocol):
|
||||||
|
def call_interrupt(
|
||||||
|
self,
|
||||||
|
t: int,
|
||||||
|
seed: int,
|
||||||
|
prompt: str,
|
||||||
|
latents: mx.array,
|
||||||
|
config: Config,
|
||||||
|
time_steps: tqdm,
|
||||||
|
) -> None: ...
|
||||||
24
.mlx_typings/mflux/callbacks/callback_registry.pyi
Normal file
24
.mlx_typings/mflux/callbacks/callback_registry.pyi
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.callbacks.callback import (
|
||||||
|
AfterLoopCallback,
|
||||||
|
BeforeLoopCallback,
|
||||||
|
InLoopCallback,
|
||||||
|
InterruptCallback,
|
||||||
|
)
|
||||||
|
from mflux.callbacks.generation_context import GenerationContext
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class CallbackRegistry:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def register(self, callback) -> None: ...
|
||||||
|
def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
|
||||||
|
def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
|
||||||
|
def in_loop_callbacks(self) -> list[InLoopCallback]: ...
|
||||||
|
def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
|
||||||
|
def interrupt_callbacks(self) -> list[InterruptCallback]: ...
|
||||||
29
.mlx_typings/mflux/callbacks/generation_context.pyi
Normal file
29
.mlx_typings/mflux/callbacks/generation_context.pyi
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import PIL.Image
|
||||||
|
import tqdm
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.callbacks.callback_registry import CallbackRegistry
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class GenerationContext:
|
||||||
|
def __init__(
|
||||||
|
self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
|
||||||
|
) -> None: ...
|
||||||
|
def before_loop(
|
||||||
|
self,
|
||||||
|
latents: mx.array,
|
||||||
|
*,
|
||||||
|
canny_image: PIL.Image.Image | None = ...,
|
||||||
|
depth_image: PIL.Image.Image | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...
|
||||||
|
def after_loop(self, latents: mx.array) -> None: ...
|
||||||
|
def interruption(
|
||||||
|
self, t: int, latents: mx.array, time_steps: tqdm = ...
|
||||||
|
) -> None: ...
|
||||||
3
.mlx_typings/mflux/cli/__init__.pyi
Normal file
3
.mlx_typings/mflux/cli/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
22
.mlx_typings/mflux/cli/defaults/defaults.pyi
Normal file
22
.mlx_typings/mflux/cli/defaults/defaults.pyi
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
BATTERY_PERCENTAGE_STOP_LIMIT = ...
|
||||||
|
CONTROLNET_STRENGTH = ...
|
||||||
|
DEFAULT_DEV_FILL_GUIDANCE = ...
|
||||||
|
DEFAULT_DEPTH_GUIDANCE = ...
|
||||||
|
DIMENSION_STEP_PIXELS = ...
|
||||||
|
GUIDANCE_SCALE = ...
|
||||||
|
GUIDANCE_SCALE_KONTEXT = ...
|
||||||
|
IMAGE_STRENGTH = ...
|
||||||
|
MODEL_CHOICES = ...
|
||||||
|
MODEL_INFERENCE_STEPS = ...
|
||||||
|
QUANTIZE_CHOICES = ...
|
||||||
|
if os.environ.get("MFLUX_CACHE_DIR"):
|
||||||
|
MFLUX_CACHE_DIR = ...
|
||||||
|
else:
|
||||||
|
MFLUX_CACHE_DIR = ...
|
||||||
|
MFLUX_LORA_CACHE_DIR = ...
|
||||||
3
.mlx_typings/mflux/models/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
3
.mlx_typings/mflux/models/common/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/common/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
3
.mlx_typings/mflux/models/common/cli/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/common/cli/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
8
.mlx_typings/mflux/models/common/config/__init__.pyi
Normal file
8
.mlx_typings/mflux/models/common/config/__init__.pyi
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
from mflux.models.common.config.model_config import ModelConfig
|
||||||
|
|
||||||
|
__all__ = ["Config", "ModelConfig"]
|
||||||
66
.mlx_typings/mflux/models/common/config/config.pyi
Normal file
66
.mlx_typings/mflux/models/common/config/config.pyi
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from tqdm import tqdm
|
||||||
|
from mflux.models.common.config.model_config import ModelConfig
|
||||||
|
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
num_inference_steps: int = ...,
|
||||||
|
height: int = ...,
|
||||||
|
width: int = ...,
|
||||||
|
guidance: float = ...,
|
||||||
|
image_path: Path | str | None = ...,
|
||||||
|
image_strength: float | None = ...,
|
||||||
|
depth_image_path: Path | str | None = ...,
|
||||||
|
redux_image_paths: list[Path | str] | None = ...,
|
||||||
|
redux_image_strengths: list[float] | None = ...,
|
||||||
|
masked_image_path: Path | str | None = ...,
|
||||||
|
controlnet_strength: float | None = ...,
|
||||||
|
scheduler: str = ...,
|
||||||
|
) -> None: ...
|
||||||
|
@property
|
||||||
|
def height(self) -> int: ...
|
||||||
|
@property
|
||||||
|
def width(self) -> int: ...
|
||||||
|
@width.setter
|
||||||
|
def width(self, value): # -> None:
|
||||||
|
...
|
||||||
|
@property
|
||||||
|
def image_seq_len(self) -> int: ...
|
||||||
|
@property
|
||||||
|
def guidance(self) -> float: ...
|
||||||
|
@property
|
||||||
|
def num_inference_steps(self) -> int: ...
|
||||||
|
@property
|
||||||
|
def precision(self) -> mx.Dtype: ...
|
||||||
|
@property
|
||||||
|
def num_train_steps(self) -> int: ...
|
||||||
|
@property
|
||||||
|
def image_path(self) -> Path | None: ...
|
||||||
|
@property
|
||||||
|
def image_strength(self) -> float | None: ...
|
||||||
|
@property
|
||||||
|
def depth_image_path(self) -> Path | None: ...
|
||||||
|
@property
|
||||||
|
def redux_image_paths(self) -> list[Path] | None: ...
|
||||||
|
@property
|
||||||
|
def redux_image_strengths(self) -> list[float] | None: ...
|
||||||
|
@property
|
||||||
|
def masked_image_path(self) -> Path | None: ...
|
||||||
|
@property
|
||||||
|
def init_time_step(self) -> int: ...
|
||||||
|
@property
|
||||||
|
def time_steps(self) -> tqdm: ...
|
||||||
|
@property
|
||||||
|
def controlnet_strength(self) -> float | None: ...
|
||||||
|
@property
|
||||||
|
def scheduler(self) -> Any: ...
|
||||||
86
.mlx_typings/mflux/models/common/config/model_config.pyi
Normal file
86
.mlx_typings/mflux/models/common/config/model_config.pyi
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
class ModelConfig:
|
||||||
|
precision: mx.Dtype = ...
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
priority: int,
|
||||||
|
aliases: list[str],
|
||||||
|
model_name: str,
|
||||||
|
base_model: str | None,
|
||||||
|
controlnet_model: str | None,
|
||||||
|
custom_transformer_model: str | None,
|
||||||
|
num_train_steps: int | None,
|
||||||
|
max_sequence_length: int | None,
|
||||||
|
supports_guidance: bool | None,
|
||||||
|
requires_sigma_shift: bool | None,
|
||||||
|
transformer_overrides: dict | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def schnell() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_kontext() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_fill() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_redux() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_depth() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_controlnet_canny() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def schnell_controlnet_canny() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_controlnet_upscaler() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def dev_fill_catvton() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def krea_dev() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def flux2_klein_4b() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def flux2_klein_9b() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def qwen_image() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def qwen_image_edit() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def fibo() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def z_image_turbo() -> ModelConfig: ...
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache
|
||||||
|
def seedvr2_3b() -> ModelConfig: ...
|
||||||
|
def x_embedder_input_dim(self) -> int: ...
|
||||||
|
def is_canny(self) -> bool: ...
|
||||||
|
@staticmethod
|
||||||
|
def from_name(
|
||||||
|
model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
|
||||||
|
) -> ModelConfig: ...
|
||||||
|
|
||||||
|
AVAILABLE_MODELS = ...
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
from mlx import nn
|
||||||
|
from mflux.models.common.vae.tiling_config import TilingConfig
|
||||||
|
from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
|
||||||
|
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||||
|
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||||
|
from mflux.models.z_image.latent_creator.z_image_latent_creator import (
|
||||||
|
ZImageLatentCreator,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
LatentCreatorType: TypeAlias = type[
|
||||||
|
FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
|
||||||
|
]
|
||||||
|
|
||||||
|
class Img2Img:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae: nn.Module,
|
||||||
|
latent_creator: LatentCreatorType,
|
||||||
|
sigmas: mx.array,
|
||||||
|
init_time_step: int,
|
||||||
|
image_path: str | Path | None,
|
||||||
|
tiling_config: TilingConfig | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class LatentCreator:
|
||||||
|
@staticmethod
|
||||||
|
def create_for_txt2img_or_img2img(
|
||||||
|
seed: int, height: int, width: int, img2img: Img2Img
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def encode_image(
|
||||||
|
vae: nn.Module,
|
||||||
|
image_path: str | Path,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
tiling_config: TilingConfig | None = ...,
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def add_noise_by_interpolation(
|
||||||
|
clean: mx.array, noise: mx.array, sigma: float
|
||||||
|
) -> mx.array: ...
|
||||||
3
.mlx_typings/mflux/models/common/lora/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/common/lora/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mlx import nn
|
||||||
|
from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
|
||||||
|
|
||||||
|
class FusedLoRALinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
|
||||||
|
) -> None: ...
|
||||||
|
def __call__(self, x): # -> array:
|
||||||
|
...
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
@staticmethod
|
||||||
|
def from_linear(
|
||||||
|
linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
|
||||||
|
): # -> LoRALinear:
|
||||||
|
...
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dims: int,
|
||||||
|
output_dims: int,
|
||||||
|
r: int = ...,
|
||||||
|
scale: float = ...,
|
||||||
|
bias: bool = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def __call__(self, x): # -> array:
|
||||||
|
...
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PatternMatch:
|
||||||
|
source_pattern: str
|
||||||
|
target_path: str
|
||||||
|
matrix_name: str
|
||||||
|
transpose: bool
|
||||||
|
transform: Callable[[mx.array], mx.array] | None = ...
|
||||||
|
|
||||||
|
class LoRALoader:
|
||||||
|
@staticmethod
|
||||||
|
def load_and_apply_lora(
|
||||||
|
lora_mapping: list[LoRATarget],
|
||||||
|
transformer: nn.Module,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
) -> tuple[list[str], list[float]]: ...
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Protocol
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRATarget:
|
||||||
|
model_path: str
|
||||||
|
possible_up_patterns: List[str]
|
||||||
|
possible_down_patterns: List[str]
|
||||||
|
possible_alpha_patterns: List[str] = ...
|
||||||
|
up_transform: Callable[[mx.array], mx.array] | None = ...
|
||||||
|
down_transform: Callable[[mx.array], mx.array] | None = ...
|
||||||
|
|
||||||
|
class LoRAMapping(Protocol):
|
||||||
|
@staticmethod
|
||||||
|
def get_mapping() -> List[LoRATarget]: ...
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class LoRASaver:
|
||||||
|
@staticmethod
|
||||||
|
def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
class LoraTransforms:
|
||||||
|
@staticmethod
|
||||||
|
def split_q_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_k_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_v_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_q_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_k_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_v_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_q_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_k_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_v_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_q_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_k_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_v_down(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def split_single_mlp_down(tensor: mx.array) -> mx.array: ...
|
||||||
17
.mlx_typings/mflux/models/common/resolution/__init__.pyi
Normal file
17
.mlx_typings/mflux/models/common/resolution/__init__.pyi
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.resolution.config_resolution import ConfigResolution
|
||||||
|
from mflux.models.common.resolution.lora_resolution import LoraResolution
|
||||||
|
from mflux.models.common.resolution.path_resolution import PathResolution
|
||||||
|
from mflux.models.common.resolution.quantization_resolution import (
|
||||||
|
QuantizationResolution,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ConfigResolution",
|
||||||
|
"LoraResolution",
|
||||||
|
"PathResolution",
|
||||||
|
"QuantizationResolution",
|
||||||
|
]
|
||||||
39
.mlx_typings/mflux/models/common/resolution/actions.pyi
Normal file
39
.mlx_typings/mflux/models/common/resolution/actions.pyi
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
class QuantizationAction(Enum):
|
||||||
|
NONE = ...
|
||||||
|
STORED = ...
|
||||||
|
REQUESTED = ...
|
||||||
|
|
||||||
|
class PathAction(Enum):
|
||||||
|
LOCAL = ...
|
||||||
|
HUGGINGFACE_CACHED = ...
|
||||||
|
HUGGINGFACE = ...
|
||||||
|
ERROR = ...
|
||||||
|
|
||||||
|
class LoraAction(Enum):
|
||||||
|
LOCAL = ...
|
||||||
|
REGISTRY = ...
|
||||||
|
HUGGINGFACE_COLLECTION_CACHED = ...
|
||||||
|
HUGGINGFACE_COLLECTION = ...
|
||||||
|
HUGGINGFACE_REPO_CACHED = ...
|
||||||
|
HUGGINGFACE_REPO = ...
|
||||||
|
ERROR = ...
|
||||||
|
|
||||||
|
class ConfigAction(Enum):
|
||||||
|
EXACT_MATCH = ...
|
||||||
|
EXPLICIT_BASE = ...
|
||||||
|
INFER_SUBSTRING = ...
|
||||||
|
ERROR = ...
|
||||||
|
|
||||||
|
class Rule(NamedTuple):
|
||||||
|
priority: int
|
||||||
|
name: str
|
||||||
|
check: str
|
||||||
|
action: QuantizationAction | PathAction | LoraAction | ConfigAction
|
||||||
|
...
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.config.model_config import ModelConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class ConfigResolution:
|
||||||
|
RULES = ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class LoraResolution:
|
||||||
|
RULES = ...
|
||||||
|
_registry: dict[str, Path] = ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve(path: str) -> str: ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve_paths(paths: list[str] | None) -> list[str]: ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_registry() -> dict[str, Path]: ...
|
||||||
|
@staticmethod
|
||||||
|
def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class PathResolution:
|
||||||
|
RULES = ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class QuantizationResolution:
|
||||||
|
RULES = ...
|
||||||
|
@staticmethod
|
||||||
|
def resolve(
|
||||||
|
stored: int | None, requested: int | None
|
||||||
|
) -> tuple[int | None, str | None]: ...
|
||||||
26
.mlx_typings/mflux/models/common/schedulers/__init__.pyi
Normal file
26
.mlx_typings/mflux/models/common/schedulers/__init__.pyi
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
|
||||||
|
from .linear_scheduler import LinearScheduler
|
||||||
|
from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinearScheduler",
|
||||||
|
"FlowMatchEulerDiscreteScheduler",
|
||||||
|
"SeedVR2EulerScheduler",
|
||||||
|
]
|
||||||
|
|
||||||
|
class SchedulerModuleNotFound(ValueError): ...
|
||||||
|
class SchedulerClassNotFound(ValueError): ...
|
||||||
|
class InvalidSchedulerType(TypeError): ...
|
||||||
|
|
||||||
|
SCHEDULER_REGISTRY = ...
|
||||||
|
|
||||||
|
def register_contrib(scheduler_object, scheduler_name=...): # -> None:
|
||||||
|
...
|
||||||
|
def try_import_external_scheduler(
|
||||||
|
scheduler_object_path: str,
|
||||||
|
): # -> type[BaseScheduler]:
|
||||||
|
...
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
class BaseScheduler(ABC):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def sigmas(self) -> mx.array: ...
|
||||||
|
@abstractmethod
|
||||||
|
def step(
|
||||||
|
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
|
||||||
|
) -> mx.array: ...
|
||||||
|
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class FlowMatchEulerDiscreteScheduler(BaseScheduler):
|
||||||
|
def __init__(self, config: Config) -> None: ...
|
||||||
|
@property
|
||||||
|
def sigmas(self) -> mx.array: ...
|
||||||
|
@property
|
||||||
|
def timesteps(self) -> mx.array: ...
|
||||||
|
def set_image_seq_len(self, image_seq_len: int) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_timesteps_and_sigmas(
|
||||||
|
image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
|
def step(
|
||||||
|
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
|
||||||
|
) -> mx.array: ...
|
||||||
|
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class LinearScheduler(BaseScheduler):
|
||||||
|
def __init__(self, config: Config) -> None: ...
|
||||||
|
@property
|
||||||
|
def sigmas(self) -> mx.array: ...
|
||||||
|
@property
|
||||||
|
def timesteps(self) -> mx.array: ...
|
||||||
|
def step(
|
||||||
|
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
|
||||||
|
) -> mx.array: ...
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.config.config import Config
|
||||||
|
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class SeedVR2EulerScheduler(BaseScheduler):
|
||||||
|
def __init__(self, config: Config) -> None: ...
|
||||||
|
@property
|
||||||
|
def timesteps(self) -> mx.array: ...
|
||||||
|
@property
|
||||||
|
def sigmas(self) -> mx.array: ...
|
||||||
|
def step(
|
||||||
|
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
|
||||||
|
) -> mx.array: ...
|
||||||
24
.mlx_typings/mflux/models/common/tokenizer/__init__.pyi
Normal file
24
.mlx_typings/mflux/models/common/tokenizer/__init__.pyi
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.tokenizer.tokenizer import (
|
||||||
|
BaseTokenizer,
|
||||||
|
LanguageTokenizer,
|
||||||
|
Tokenizer,
|
||||||
|
VisionLanguageTokenizer,
|
||||||
|
)
|
||||||
|
from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
|
||||||
|
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
__all__ = [
|
||||||
|
"Tokenizer",
|
||||||
|
"BaseTokenizer",
|
||||||
|
"LanguageTokenizer",
|
||||||
|
"VisionLanguageTokenizer",
|
||||||
|
"TokenizerLoader",
|
||||||
|
"TokenizerOutput",
|
||||||
|
]
|
||||||
74
.mlx_typings/mflux/models/common/tokenizer/tokenizer.pyi
Normal file
74
.mlx_typings/mflux/models/common/tokenizer/tokenizer.pyi
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Tokenizer(Protocol):
|
||||||
|
tokenizer: PreTrainedTokenizer
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
prompt: str | list[str],
|
||||||
|
images: list[Image.Image] | None = ...,
|
||||||
|
max_length: int | None = ...,
|
||||||
|
**kwargs,
|
||||||
|
) -> TokenizerOutput: ...
|
||||||
|
|
||||||
|
class BaseTokenizer(ABC):
|
||||||
|
def __init__(
|
||||||
|
self, tokenizer: PreTrainedTokenizer, max_length: int = ...
|
||||||
|
) -> None: ...
|
||||||
|
@abstractmethod
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
prompt: str | list[str],
|
||||||
|
images: list[Image.Image] | None = ...,
|
||||||
|
max_length: int | None = ...,
|
||||||
|
**kwargs,
|
||||||
|
) -> TokenizerOutput: ...
|
||||||
|
|
||||||
|
class LanguageTokenizer(BaseTokenizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
max_length: int = ...,
|
||||||
|
padding: str = ...,
|
||||||
|
return_attention_mask: bool = ...,
|
||||||
|
template: str | None = ...,
|
||||||
|
use_chat_template: bool = ...,
|
||||||
|
chat_template_kwargs: dict | None = ...,
|
||||||
|
add_special_tokens: bool = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
prompt: str | list[str],
|
||||||
|
images: list[Image.Image] | None = ...,
|
||||||
|
max_length: int | None = ...,
|
||||||
|
**kwargs,
|
||||||
|
) -> TokenizerOutput: ...
|
||||||
|
|
||||||
|
class VisionLanguageTokenizer(BaseTokenizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
processor,
|
||||||
|
max_length: int = ...,
|
||||||
|
template: str | None = ...,
|
||||||
|
image_token: str = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
prompt: str | list[str],
|
||||||
|
images: list[Image.Image] | None = ...,
|
||||||
|
max_length: int | None = ...,
|
||||||
|
**kwargs,
|
||||||
|
) -> TokenizerOutput: ...
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class TokenizerLoader:
|
||||||
|
@staticmethod
|
||||||
|
def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
|
||||||
|
@staticmethod
|
||||||
|
def load_all(
|
||||||
|
definitions: list[TokenizerDefinition],
|
||||||
|
model_path: str,
|
||||||
|
max_length_overrides: dict[str, int] | None = ...,
|
||||||
|
) -> dict[str, BaseTokenizer]: ...
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenizerOutput:
|
||||||
|
input_ids: mx.array
|
||||||
|
attention_mask: mx.array
|
||||||
|
pixel_values: mx.array | None = ...
|
||||||
|
image_grid_thw: mx.array | None = ...
|
||||||
8
.mlx_typings/mflux/models/common/vae/__init__.pyi
Normal file
8
.mlx_typings/mflux/models/common/vae/__init__.pyi
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.vae.tiling_config import TilingConfig
|
||||||
|
from mflux.models.common.vae.vae_tiler import VAETiler
|
||||||
|
|
||||||
|
__all__ = ["TilingConfig", "VAETiler"]
|
||||||
13
.mlx_typings/mflux/models/common/vae/tiling_config.pyi
Normal file
13
.mlx_typings/mflux/models/common/vae/tiling_config.pyi
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class TilingConfig:
|
||||||
|
vae_decode_tiles_per_dim: int | None = ...
|
||||||
|
vae_decode_overlap: int = ...
|
||||||
|
vae_encode_tiled: bool = ...
|
||||||
|
vae_encode_tile_size: int = ...
|
||||||
|
vae_encode_tile_overlap: int = ...
|
||||||
27
.mlx_typings/mflux/models/common/vae/vae_tiler.pyi
Normal file
27
.mlx_typings/mflux/models/common/vae/vae_tiler.pyi
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
class VAETiler:
|
||||||
|
@staticmethod
|
||||||
|
def encode_image_tiled(
|
||||||
|
*,
|
||||||
|
image: mx.array,
|
||||||
|
encode_fn: Callable[[mx.array], mx.array],
|
||||||
|
latent_channels: int,
|
||||||
|
tile_size: tuple[int, int] = ...,
|
||||||
|
tile_overlap: tuple[int, int] = ...,
|
||||||
|
spatial_scale: int = ...,
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def decode_image_tiled(
|
||||||
|
*,
|
||||||
|
latent: mx.array,
|
||||||
|
decode_fn: Callable[[mx.array], mx.array],
|
||||||
|
tile_size: tuple[int, int] = ...,
|
||||||
|
tile_overlap: tuple[int, int] = ...,
|
||||||
|
spatial_scale: int = ...,
|
||||||
|
) -> mx.array: ...
|
||||||
17
.mlx_typings/mflux/models/common/vae/vae_util.pyi
Normal file
17
.mlx_typings/mflux/models/common/vae/vae_util.pyi
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
from mflux.models.common.vae.tiling_config import TilingConfig
|
||||||
|
|
||||||
|
class VAEUtil:
|
||||||
|
@staticmethod
|
||||||
|
def encode(
|
||||||
|
vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def decode(
|
||||||
|
vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
|
||||||
|
) -> mx.array: ...
|
||||||
18
.mlx_typings/mflux/models/common/weights/__init__.pyi
Normal file
18
.mlx_typings/mflux/models/common/weights/__init__.pyi
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
|
||||||
|
from mflux.models.common.weights.loading.weight_applier import WeightApplier
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
|
||||||
|
from mflux.models.common.weights.loading.weight_loader import WeightLoader
|
||||||
|
from mflux.models.common.weights.saving.model_saver import ModelSaver
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ComponentDefinition",
|
||||||
|
"LoadedWeights",
|
||||||
|
"MetaData",
|
||||||
|
"ModelSaver",
|
||||||
|
"WeightApplier",
|
||||||
|
"WeightLoader",
|
||||||
|
]
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetaData:
|
||||||
|
quantization_level: int | None = ...
|
||||||
|
mflux_version: str | None = ...
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoadedWeights:
|
||||||
|
components: dict[str, dict]
|
||||||
|
meta_data: MetaData
|
||||||
|
def __getattr__(self, name: str) -> dict | None: ...
|
||||||
|
def num_transformer_blocks(self, component_name: str = ...) -> int: ...
|
||||||
|
def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.nn as nn
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import (
|
||||||
|
ComponentDefinition,
|
||||||
|
WeightDefinitionType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class WeightApplier:
|
||||||
|
@staticmethod
|
||||||
|
def apply_and_quantize_single(
|
||||||
|
weights: LoadedWeights,
|
||||||
|
model: nn.Module,
|
||||||
|
component: ComponentDefinition,
|
||||||
|
quantize_arg: int | None,
|
||||||
|
quantization_predicate=...,
|
||||||
|
) -> int | None: ...
|
||||||
|
@staticmethod
|
||||||
|
def apply_and_quantize(
|
||||||
|
weights: LoadedWeights,
|
||||||
|
models: dict[str, nn.Module],
|
||||||
|
quantize_arg: int | None,
|
||||||
|
weight_definition: WeightDefinitionType,
|
||||||
|
) -> int | None: ...
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, List, TYPE_CHECKING, TypeAlias
|
||||||
|
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
|
||||||
|
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
|
||||||
|
from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
|
||||||
|
DepthProWeightDefinition,
|
||||||
|
)
|
||||||
|
from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
|
||||||
|
from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
|
||||||
|
FIBOVLMWeightDefinition,
|
||||||
|
)
|
||||||
|
from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
|
||||||
|
from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
|
||||||
|
from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
|
||||||
|
SeedVR2WeightDefinition,
|
||||||
|
)
|
||||||
|
from mflux.models.z_image.weights.z_image_weight_definition import (
|
||||||
|
ZImageWeightDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
WeightDefinitionType: TypeAlias = type[
|
||||||
|
FluxWeightDefinition
|
||||||
|
| FIBOWeightDefinition
|
||||||
|
| FIBOVLMWeightDefinition
|
||||||
|
| QwenWeightDefinition
|
||||||
|
| ZImageWeightDefinition
|
||||||
|
| SeedVR2WeightDefinition
|
||||||
|
| DepthProWeightDefinition
|
||||||
|
]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComponentDefinition:
|
||||||
|
name: str
|
||||||
|
hf_subdir: str
|
||||||
|
mapping_getter: Callable[[], List[WeightTarget]] | None = ...
|
||||||
|
model_attr: str | None = ...
|
||||||
|
num_blocks: int | None = ...
|
||||||
|
num_layers: int | None = ...
|
||||||
|
loading_mode: str = ...
|
||||||
|
precision: mx.Dtype | None = ...
|
||||||
|
skip_quantization: bool = ...
|
||||||
|
bulk_transform: Callable[[mx.array], mx.array] | None = ...
|
||||||
|
weight_subkey: str | None = ...
|
||||||
|
download_url: str | None = ...
|
||||||
|
weight_prefix_filters: List[str] | None = ...
|
||||||
|
weight_files: List[str] | None = ...
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenizerDefinition:
|
||||||
|
name: str
|
||||||
|
hf_subdir: str
|
||||||
|
tokenizer_class: str = ...
|
||||||
|
fallback_subdirs: List[str] | None = ...
|
||||||
|
download_patterns: List[str] | None = ...
|
||||||
|
encoder_class: type[BaseTokenizer] | None = ...
|
||||||
|
max_length: int = ...
|
||||||
|
padding: str = ...
|
||||||
|
template: str | None = ...
|
||||||
|
use_chat_template: bool = ...
|
||||||
|
chat_template_kwargs: dict | None = ...
|
||||||
|
add_special_tokens: bool = ...
|
||||||
|
processor_class: type | None = ...
|
||||||
|
image_token: str = ...
|
||||||
|
chat_template: str | None = ...
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import (
|
||||||
|
ComponentDefinition,
|
||||||
|
WeightDefinitionType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
logger = ...
|
||||||
|
|
||||||
|
class WeightLoader:
|
||||||
|
@staticmethod
|
||||||
|
def load_single(
|
||||||
|
component: ComponentDefinition, repo_id: str, file_pattern: str = ...
|
||||||
|
) -> LoadedWeights: ...
|
||||||
|
@staticmethod
|
||||||
|
def load(
|
||||||
|
weight_definition: WeightDefinitionType, model_path: str | None = ...
|
||||||
|
) -> LoadedWeights: ...
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
|
||||||
|
|
||||||
|
class WeightMapper:
|
||||||
|
@staticmethod
|
||||||
|
def apply_mapping(
|
||||||
|
hf_weights: Dict[str, mx.array],
|
||||||
|
mapping: List[WeightTarget],
|
||||||
|
num_blocks: Optional[int] = ...,
|
||||||
|
num_layers: Optional[int] = ...,
|
||||||
|
) -> Dict: ...
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, List, Optional, Protocol
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WeightTarget:
|
||||||
|
to_pattern: str
|
||||||
|
from_pattern: List[str]
|
||||||
|
transform: Optional[Callable[[mx.array], mx.array]] = ...
|
||||||
|
required: bool = ...
|
||||||
|
max_blocks: Optional[int] = ...
|
||||||
|
|
||||||
|
class WeightMapping(Protocol):
|
||||||
|
@staticmethod
|
||||||
|
def get_mapping() -> List[WeightTarget]: ...
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
class WeightTransforms:
|
||||||
|
@staticmethod
|
||||||
|
def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
|
||||||
|
|
||||||
|
if TYPE_CHECKING: ...
|
||||||
|
|
||||||
|
class ModelSaver:
|
||||||
|
@staticmethod
|
||||||
|
def save_model(
|
||||||
|
model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
|
||||||
|
) -> None: ...
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
|
||||||
|
|
||||||
|
class DepthProInitializer:
|
||||||
|
@staticmethod
|
||||||
|
def init(model: DepthProModel, quantize: int | None = ...) -> None: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class FeatureFusionBlock2d(nn.Module):
|
||||||
|
def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
|
||||||
|
def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class MultiresConvDecoder(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x0_latent: mx.array,
|
||||||
|
x1_latent: mx.array,
|
||||||
|
x0_features: mx.array,
|
||||||
|
x1_features: mx.array,
|
||||||
|
x_global_features: mx.array,
|
||||||
|
) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, num_features: int) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
20
.mlx_typings/mflux/models/depth_pro/model/depth_pro.pyi
Normal file
20
.mlx_typings/mflux/models/depth_pro/model/depth_pro.pyi
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DepthResult:
|
||||||
|
depth_image: Image.Image
|
||||||
|
depth_array: mx.array
|
||||||
|
min_depth: float
|
||||||
|
max_depth: float
|
||||||
|
...
|
||||||
|
|
||||||
|
class DepthPro:
|
||||||
|
def __init__(self, quantize: int | None = ...) -> None: ...
|
||||||
|
def create_depth_map(self, image_path: str | Path) -> DepthResult: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class DepthProModel(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, x0: mx.array, x1: mx.array, x2: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
15
.mlx_typings/mflux/models/depth_pro/model/depth_pro_util.pyi
Normal file
15
.mlx_typings/mflux/models/depth_pro/model/depth_pro_util.pyi
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class DepthProUtil:
|
||||||
|
@staticmethod
|
||||||
|
def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def interpolate(x: mx.array, size=..., scale_factor=...): # -> array:
|
||||||
|
...
|
||||||
|
@staticmethod
|
||||||
|
def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
|
||||||
|
) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class DinoVisionTransformer(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
|
||||||
|
def __init__(self, dims: int, init_values: float = ...) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
10
.mlx_typings/mflux/models/depth_pro/model/dino_v2/mlp.pyi
Normal file
10
.mlx_typings/mflux/models/depth_pro/model/dino_v2/mlp.pyi
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class DepthProEncoder(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, x0: mx.array, x1: mx.array, x2: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class UpSampleBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_in: int = ...,
|
||||||
|
dim_int: int = ...,
|
||||||
|
dim_out: int = ...,
|
||||||
|
upsample_layers: int = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
10
.mlx_typings/mflux/models/depth_pro/model/head/fov_head.pyi
Normal file
10
.mlx_typings/mflux/models/depth_pro/model/head/fov_head.pyi
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
class FOVHead(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, x: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import (
|
||||||
|
ComponentDefinition,
|
||||||
|
TokenizerDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class DepthProWeightDefinition:
|
||||||
|
@staticmethod
|
||||||
|
def get_components() -> List[ComponentDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_tokenizers() -> List[TokenizerDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_download_patterns() -> List[str]: ...
|
||||||
|
@staticmethod
|
||||||
|
def quantization_predicate(path: str, module) -> bool: ...
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.mapping.weight_mapping import (
|
||||||
|
WeightMapping,
|
||||||
|
WeightTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
class DepthProWeightMapping(WeightMapping):
|
||||||
|
@staticmethod
|
||||||
|
def get_mapping() -> List[WeightTarget]: ...
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
class FiboLatentCreator:
|
||||||
|
@staticmethod
|
||||||
|
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import (
|
||||||
|
ComponentDefinition,
|
||||||
|
TokenizerDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FIBOWeightDefinition:
|
||||||
|
@staticmethod
|
||||||
|
def get_components() -> List[ComponentDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_tokenizers() -> List[TokenizerDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_download_patterns() -> List[str]: ...
|
||||||
|
@staticmethod
|
||||||
|
def quantization_predicate(path: str, module) -> bool: ...
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.mapping.weight_mapping import (
|
||||||
|
WeightMapping,
|
||||||
|
WeightTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FIBOWeightMapping(WeightMapping):
|
||||||
|
@staticmethod
|
||||||
|
def get_transformer_mapping() -> List[WeightTarget]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_text_encoder_mapping() -> List[WeightTarget]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_vae_mapping() -> List[WeightTarget]: ...
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor
|
||||||
|
|
||||||
|
class Qwen2VLImageProcessor(QwenImageProcessor):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
class Qwen2VLProcessor:
|
||||||
|
def __init__(self, tokenizer) -> None: ...
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages,
|
||||||
|
tokenize: bool = ...,
|
||||||
|
add_generation_prompt: bool = ...,
|
||||||
|
return_tensors: Optional[str] = ...,
|
||||||
|
return_dict: bool = ...,
|
||||||
|
**kwargs,
|
||||||
|
): # -> dict[Any, Any]:
|
||||||
|
...
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: Optional[Union[str, list[str]]] = ...,
|
||||||
|
images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
|
||||||
|
padding: bool = ...,
|
||||||
|
return_tensors: Optional[str] = ...,
|
||||||
|
**kwargs,
|
||||||
|
): # -> dict[Any, Any]:
|
||||||
|
...
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.loading.weight_definition import (
|
||||||
|
ComponentDefinition,
|
||||||
|
TokenizerDefinition,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
QWEN2VL_CHAT_TEMPLATE = ...
|
||||||
|
|
||||||
|
class FIBOVLMWeightDefinition:
|
||||||
|
@staticmethod
|
||||||
|
def get_components() -> List[ComponentDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_tokenizers() -> List[TokenizerDefinition]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_download_patterns() -> List[str]: ...
|
||||||
|
@staticmethod
|
||||||
|
def quantization_predicate(path: str, module) -> bool: ...
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from mflux.models.common.weights.mapping.weight_mapping import (
|
||||||
|
WeightMapping,
|
||||||
|
WeightTarget,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FIBOVLMWeightMapping(WeightMapping):
|
||||||
|
@staticmethod
|
||||||
|
def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...
|
||||||
3
.mlx_typings/mflux/models/flux/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/flux/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
3
.mlx_typings/mflux/models/flux/cli/__init__.pyi
Normal file
3
.mlx_typings/mflux/models/flux/cli/__init__.pyi
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
53
.mlx_typings/mflux/models/flux/flux_initializer.pyi
Normal file
53
.mlx_typings/mflux/models/flux/flux_initializer.pyi
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mflux.models.common.config import ModelConfig
|
||||||
|
|
||||||
|
class FluxInitializer:
|
||||||
|
@staticmethod
|
||||||
|
def init(
|
||||||
|
model,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
quantize: int | None,
|
||||||
|
model_path: str | None = ...,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
custom_transformer=...,
|
||||||
|
) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
def init_depth(
|
||||||
|
model,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
quantize: int | None,
|
||||||
|
model_path: str | None = ...,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
def init_redux(
|
||||||
|
model,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
quantize: int | None,
|
||||||
|
model_path: str | None = ...,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
def init_controlnet(
|
||||||
|
model,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
quantize: int | None,
|
||||||
|
model_path: str | None = ...,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
@staticmethod
|
||||||
|
def init_concept(
|
||||||
|
model,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
quantize: int | None,
|
||||||
|
model_path: str | None = ...,
|
||||||
|
lora_paths: list[str] | None = ...,
|
||||||
|
lora_scales: list[float] | None = ...,
|
||||||
|
) -> None: ...
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FluxLatentCreator:
|
||||||
|
@staticmethod
|
||||||
|
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def pack_latents(
|
||||||
|
latents: mx.array, height: int, width: int, num_channels_latents: int = ...
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
|
||||||
7
.mlx_typings/mflux/models/flux/model/__init__.pyi
Normal file
7
.mlx_typings/mflux/models/flux/model/__init__.pyi
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class CLIPEmbeddings(nn.Module):
|
||||||
|
def __init__(self, dims: int) -> None: ...
|
||||||
|
def __call__(self, tokens: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class CLIPEncoder(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, tokens: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, layer: int) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, causal_attention_mask: mx.array
|
||||||
|
) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class CLIPMLP(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def quick_gelu(input_array: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class CLIPSdpaAttention(nn.Module):
|
||||||
|
head_dimension = ...
|
||||||
|
batch_size = ...
|
||||||
|
num_heads = ...
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, causal_attention_mask: mx.array
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
|
||||||
|
...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class CLIPTextModel(nn.Module):
|
||||||
|
def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
|
||||||
|
def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||||
|
@staticmethod
|
||||||
|
def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class EncoderCLIP(nn.Module):
|
||||||
|
def __init__(self, num_encoder_layers: int) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, tokens: mx.array, causal_attention_mask: mx.array
|
||||||
|
) -> mx.array: ...
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mflux.models.common.tokenizer import Tokenizer
|
||||||
|
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
|
||||||
|
CLIPEncoder,
|
||||||
|
)
|
||||||
|
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class PromptEncoder:
|
||||||
|
@staticmethod
|
||||||
|
def encode_prompt(
|
||||||
|
prompt: str,
|
||||||
|
prompt_cache: dict[str, tuple[mx.array, mx.array]],
|
||||||
|
t5_tokenizer: Tokenizer,
|
||||||
|
clip_tokenizer: Tokenizer,
|
||||||
|
t5_text_encoder: T5Encoder,
|
||||||
|
clip_text_encoder: CLIPEncoder,
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5Attention(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5Block(nn.Module):
|
||||||
|
def __init__(self, layer: int) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5DenseReluDense(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def new_gelu(input_array: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class T5Encoder(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, tokens: mx.array): ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5FeedForward(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5LayerNorm(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class T5SelfAttention(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def shape(states): # -> array:
|
||||||
|
...
|
||||||
|
@staticmethod
|
||||||
|
def un_shape(states): # -> array:
|
||||||
|
...
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class AdaLayerNormContinuous(nn.Module):
|
||||||
|
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...
|
||||||
|
def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class AdaLayerNormZero(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, text_embeddings: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class AdaLayerNormZeroSingle(nn.Module):
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, text_embeddings: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
This type stub file was generated by pyright.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx import nn
|
||||||
|
|
||||||
|
class AttentionUtils:
|
||||||
|
@staticmethod
|
||||||
|
def process_qkv(
|
||||||
|
hidden_states: mx.array,
|
||||||
|
to_q: nn.Linear,
|
||||||
|
to_k: nn.Linear,
|
||||||
|
to_v: nn.Linear,
|
||||||
|
norm_q: nn.RMSNorm,
|
||||||
|
norm_k: nn.RMSNorm,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
) -> tuple[mx.array, mx.array, mx.array]: ...
|
||||||
|
@staticmethod
|
||||||
|
def compute_attention(
|
||||||
|
query: mx.array,
|
||||||
|
key: mx.array,
|
||||||
|
value: mx.array,
|
||||||
|
batch_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
mask: mx.array | None = ...,
|
||||||
|
) -> mx.array: ...
|
||||||
|
@staticmethod
|
||||||
|
def convert_key_padding_mask_to_additive_mask(
|
||||||
|
mask: mx.array | None, joint_seq_len: int, txt_seq_len: int
|
||||||
|
) -> mx.array | None: ...
|
||||||
|
@staticmethod
|
||||||
|
def apply_rope(
|
||||||
|
xq: mx.array, xk: mx.array, freqs_cis: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
|
@staticmethod
|
||||||
|
def apply_rope_bshd(
|
||||||
|
xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array
|
||||||
|
) -> tuple[mx.array, mx.array]: ...
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user