Compare commits

..

25 Commits

Author SHA1 Message Date
Evan
879d178057 add kimi tool parseing 2026-01-21 16:53:38 +00:00
Evan
daa31b4472 implement mlx-lm tool calling 2026-01-21 16:53:38 +00:00
ciaranbor
6a9251b920 Add mflux type stubs (#1234)
## Motivation

Simplify image generation review
2026-01-21 15:07:42 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
MLX LM has given GPT OSS a shard method, but MLX does not have an update
to match.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-20 18:24:54 +00:00
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Kimi K2 Thinking Pipeline RDMA was broken before.

## Why It Works
No clue tbh

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time on Pipeline RDMA.
Needs exo bench to check more thoroughly

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the stronger typed ModelId, and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular ModelId.normalize() to
replace `/` with `--`.

This PR did introduce some circular imports, so has moved some code
around to try and limit them.

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks if `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to HuggingFace repo format and there's no
API pathway to pass local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
true if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed in the cwd, which
doesn't happen in practice. Offline caching is already handled by
`fetch_file_list_with_cache`.

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo was logging a lot of unnecessary messages during periodic download status checks. This resulted in spammy logs that made it hard to see important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when skip_download is true)
- Changed periodic download progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. By guarding the log behind `if not skip_download:`, we avoid logging on every status check. Changing the periodic emitting logs to DEBUG level reduces noise while still keeping them available for debugging.

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use -v or -vv flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
pingers currently removes mdns discovered connections - these systems
should be independent
2026-01-20 14:46:20 +00:00
Evan Quiney
d4f551c602 Simplify model cards (#1204)
## Motivation

We have a lot of unneeded data in the model card - lets just keep the
necessary stuff and add back more data when we need it

## Test Plan

EXO still runs! (pipeline on 2)

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-20 11:01:19 +00:00
Alex Cheema
176ab5ba40 Add GLM-4.7-Flash model cards (4bit, 5bit, 6bit, 8bit) (#1214)
## Motivation

Add support for GLM-4.7-Flash, a lighter variant of GLM-4.7 with the
`glm4_moe_lite` architecture. These models are smaller and faster while
maintaining good performance.

## Changes

1. **Added 4 new model cards** for GLM-4.7-Flash variants:
   - `glm-4.7-flash-4bit` (~18 GB)
   - `glm-4.7-flash-5bit` (~21 GB)
   - `glm-4.7-flash-6bit` (~25 GB)
   - `glm-4.7-flash-8bit` (~32 GB)

   All variants have:
   - `n_layers`: 47 (vs 91 in GLM-4.7)
   - `hidden_size`: 2048 (vs 5120 in GLM-4.7)
   - `supports_tensor`: True (native `shard()` method)

2. **Bumped mlx from 0.30.1 to 0.30.3** - required by mlx-lm 0.30.4

3. **Updated mlx-lm from 0.30.2 to 0.30.4** - adds `glm4_moe_lite`
architecture support

4. **Added type ignores** in `auto_parallel.py` for stricter type
annotations in new mlx-lm

5. **Fixed EOS token IDs** for GLM-4.7-Flash - uses different tokenizer
with IDs `[154820, 154827, 154829]` vs other GLM models' `[151336,
151329, 151338]`

6. **Renamed `MLX_IBV_DEVICES` to `MLX_JACCL_DEVICES`** - env var name
changed in new mlx

## Why It Works

The model cards follow the same pattern as existing GLM-4.7 models.
Tensor parallel support is enabled because GLM-4.7-Flash implements the
native `shard()` method in mlx-lm 0.30.4, which is automatically
detected in `auto_parallel.py`.

GLM-4.7-Flash uses a new tokenizer with different special token IDs.
Without the correct EOS tokens, generation wouldn't stop properly.

## Test Plan

### Manual Testing
Tested generation with GLM-4.7-Flash-4bit - now correctly stops at EOS
tokens.

### Automated Testing
- `basedpyright`: 0 errors
- `ruff check`: All checks passed
- `pytest`: 162/162 tests pass (excluding pre-existing
`test_distributed_fix.py` timeout failures)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 03:58:09 +00:00
rltakashige
f5e6aa82d2 Load layers individually (#1211)
## Motivation

Certain models hang at model loading in tensor parallel. 

Hopefully closes #1205 

## Changes

- Load layer by layer for tensor parallel sharding
- Move eval_with_timeout to auto_parallel.py to resolve circular import.

## Why It Works

The naive way to fix this is to use load model with lazy = False and
then shard in tensor parallel. However, this requires the entire model
to be loaded into memory.

Instead, we can load layer by layer and shard after loading. There is a
very small memory footprint to this, but it is negligible.

I tried loading layer by layer after the sharding, and this allowed
model loading but got stuck at warming up.

## Test Plan

### Manual Testing
GPT OSS loads with TP and FAST SYNCH. Kimi does too.

### Automated Testing
We need to run a suite of exo_bench before merging this!
2026-01-20 03:26:51 +00:00
Alex Cheema
39f0ed6018 Prepend <think> tag to stream for thinking models like GLM-4.7 (#1186)
## Motivation

For thinking models like GLM-4.7, the `<think>` tag is inserted by the
tokenizer's `apply_chat_template()` into the **prompt** (input). The
model generates tokens starting *after* this tag, so `<think>` never
appears in the streamed output. The frontend expects
`<think>...</think>` tags to extract and display thinking content.

**Log evidence:**
```
[gMASK]<sop><|system|>...<|user|>...<|assistant|><think>
```
The prompt ends with `<think>`, so the model generates content after it,
never returning the opening tag.

## Changes

- Added `detect_thinking_prompt_suffix()` helper function in
`utils_mlx.py` to detect if a prompt ends with `<think>` tag
- Added `parse_thinking_models()` generator wrapper in `runner.py` that
prepends the thinking tag to the output stream
- Modified the main generation loop to use the thinking wrapper for
non-GptOssModel models when a thinking prefix is detected
- Updated test mocks to handle the new `apply_chat_template` call

## Why It Works

The solution follows the same pattern as `parse_gpt_oss()` - a generator
wrapper that transforms the output stream. When the chat template ends
with `<think>`, we prepend this tag to the first generated token so the
frontend receives the complete `<think>...</think>` structure it
expects.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Run exo: `uv run exo`
- Send a chat request to GLM-4.7:
  ```bash
curl http://localhost:52415/v1/chat/completions -H "Content-Type:
application/json" -d '{
    "model": "mlx-community/GLM-4.7-8bit-gs32",
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "stream": true
  }'
  ```
- Verify the streamed response starts with `<think>` tag
- Verify the frontend dashboard correctly shows the thinking section
collapsed

### Automated Testing
- All 72 worker tests pass: `uv run pytest src/exo/worker/`
- Type checker passes: `uv run basedpyright`
- Linter passes: `uv run ruff check`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-19 19:44:51 +00:00
Alex Cheema
ee43b598fe Split NodePerformanceProfile into granular state mappings (#1209)
## Motivation

The current `NodePerformanceProfile` is a monolithic object where every
update (even 1-second memory updates) replaces the entire profile,
touching unrelated data. Different fields update at vastly different
frequencies:

| Data | Update Frequency |
|------|------------------|
| Memory, System | 1 second |
| Thunderbolt | 5 seconds |
| Network interfaces | 10 seconds |
| Friendly name | 60 seconds |
| Model/Chip ID | Once at startup |

## Changes

Split into separate state mappings so each data type updates
independently:

- `node_identities`: Static and slow-changing data (model_id, chip_id,
friendly_name)
- `node_memory`: RAM and swap usage
- `node_system`: GPU usage, temperature, power, CPU metrics
- `node_network`: Network interface information
- `node_thunderbolt`: Thunderbolt interface identifiers

Added a backwards-compatible `node_profiles` property that reconstructs
`NodePerformanceProfile` from the granular mappings for dashboard
compatibility.

**Files modified:**
- `src/exo/shared/types/profiling.py` - Added `NodeIdentity`,
`NodeNetworkInfo`, `NodeThunderboltInfo` types
- `src/exo/shared/types/state.py` - Added 5 new mappings +
`node_profiles` property
- `src/exo/shared/apply.py` - Updated `apply_node_gathered_info` and
`apply_node_timed_out`

## Why It Works

Each info type now writes only to its specific mapping, avoiding
unnecessary updates to unrelated data. The `MacThunderboltConnections`
handler reads from `node_thunderbolt` instead of the old `node_profiles`
for RDMA connection mapping. The backwards-compatible property ensures
the dashboard continues to work unchanged.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Start exo and verify dashboard shows node info
- Verify memory/GPU updates stream correctly
- Check that node timeout properly cleans up all mappings

### Automated Testing
- All 162 existing tests pass
- basedpyright: 0 errors
- ruff check: All checks passed
- nix fmt: Applied

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 18:24:15 +00:00
rltakashige
5fd55594c9 Wrap pipeline models for explicit mx.depends between cache and logits (#1206)
## Motivation

GPU timeouts often when prompt size > profile_step_size. It also happens
for seemingly random models.

## Changes

Add mx.depends for cache on the logits.
All gather at the model level rather than the layer level, reducing the
amount of data sent.

## Why It Works

mlx_lm's prefill loop only evaluates cache state, not logits.
When prompt > prefill_step_size, the all_gather is never evaluated,
causing GPU timeout.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
Added failing test cases and then resolved them.
2026-01-19 17:49:42 +00:00
Jake Hillion
5ab1f8b3e2 NetworkSetupHelper: detect stale startup script content
The daemonAlreadyInstalled() function checked that the startup script
file existed and validated plist properties, but did not compare the
actual script content. If the setupScript constant was updated in a new
app version, the stale on-disk script would not be detected or replaced.

Added a guard clause that reads the installed script from disk and
compares it against the expected setupScript content (with whitespace
normalization). When content differs, the function returns false,
triggering the reinstallation flow with an admin privileges prompt.

Test plan:
- Installed on a cluster that had the previous network config. Got the
  popup asking for permissions. After accepting I could run Kimi K2
  Thinking Tensor RDMA on all 4 nodes.
2026-01-19 17:36:15 +00:00
Evan Quiney
2202685c3e refactor all information sources (including ipless rdma discovery) (#928)
## Motivation

Information gathering is tightly coupled to MacMon - we should start
generalizing our information sources so we can add more in future.

## Changes

Added a new system to gather any information. Currently, it is attached
to the Worker - though this is mostly to keep the data processing logic
simple. It could be made independent quite easily.

I also refactored topology to include different kinds of connections as
we can gather RDMA connections without having a pre-existing socket
connection, and made the relevant placement updates. We should no longer
need the network locations script in the app.

Other sources of information now include:
- static node information like "model" and "chip" (macos, "Unknown"
fallback)
- device friendly name (macos, falls back to device hostname)
- network interfaces + ips (cross platform)
- thunderbolt interfaces (macos)
- thunderbolt connections (macos)
- RAM usage (cross platform)
- per-device configuration written to EXO_HOME/config.toml

## Limitations

Model and Chip are not cross platform concepts.

We do not differentiate between unified and non-unified memory systems.

A lot of this data collection is based on simple timers. Watching the SC
store on macos is the correct way to gather some of this information,
but requires a detour into rust for macos.

## Why It Works

The InfoGatherer is a generic subsystem which returns a union of metric
datatypes. It writes them to an event, which is applied to state. It is
currently re-spawned with the worker so each cluster receives the
correct information.

As for topology, macOS identifies TB ports with a uuid in
SPThunderboltDataType, and also stores remote uuids if it can find them.
These changes read that data with the system_profiler, hopefully not so
often as to cause notable performance impacts (though this should be
tuned) but frequently enough for moderate responsiveness.
As we can identify TB connections between devices without needing ips
attached to each interface, we can remove the network setup script
(almost) completely.

## Test Plan

### Manual Testing
Spawn RDMA instances without enabling DHCP on the RDMA interfaces.

### Automated Testing
Updated the current master and shared tests to cover the topology
refactor and new events.

---------

Co-authored-by: Sami Khan <smsak99@gmail.com>
Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Jake Hillion <jake@hillion.co.uk>
2026-01-19 16:58:09 +00:00
Andrei Onel
ce3ad391b1 Update README.md with some changes from release 1.0.61 (#1157)
Updated README.md with documentation for four new features:

- added a "Benchmarking" section documenting the exo-bench tool for
measuring model performance across different placement configurations
- documented the custom namespace feature for cluster isolation in the
macOS app section
- added a "Configuration Options" subsection explaining the --no-worker
CLI flag for coordinator-only nodes
- added a "File Locations (Linux)" subsection documenting XDG Base
Directory Specification compliance on Linux systems

Issue #930
2026-01-19 16:43:18 +00:00
Jake Hillion
fb0151630d shard_downloader: make on_progress callback async
The on_progress callback was synchronous but always invoked from async
contexts, forcing the use of send_nowait() which could raise WouldBlock
if the channel buffer was full, potentially dropping progress updates.

Changed the callback type from `Callable[[ShardMetadata,
RepoDownloadProgress], None]` to return a coroutine, updated all
implementations to be async, and replaced send_nowait() with await
send() in the worker's download progress handler.

This allows proper backpressure handling when sending download progress
events through the channel, eliminating the "Footgun!" that was
previously documented in the code.

Test plan:
- Built a DMG and ran it on one node. All existing models showed as
  downloaded.
- Downloaded a new model. The progress bar on the download page worked.
- Downloaded another new model. The progress bar on the home page
  worked.
2026-01-19 16:19:37 +00:00
Alex Cheema
346b13e2c9 Enhance LaTeX rendering in dashboard markdown (#1197)
## Motivation

When models output LaTeX-formatted math proofs, the dashboard was not
rendering them correctly. Issues included:
- `\documentclass`, `\begin{document}`, `\usepackage` showing as raw
text
- `$...$` inline math with complex expressions (like `\frac`, `\ldots`)
not rendering due to markdown escaping backslashes
- `\begin{align*}...\end{align*}` and other math environments showing as
raw text
- `\emph{...}`, `\textbf{...}` LaTeX formatting commands not being
converted
- `$\require{...}$` (MathJax-specific) causing KaTeX errors
- `\begin{proof}...\end{proof}` showing as raw text

## Changes

Enhanced `MarkdownContent.svelte` with comprehensive LaTeX support:

**Math extraction before markdown processing:**
- Extract `$...$`, `$$...$$`, `\(...\)`, `\[...\]` into placeholders
before markdown processes the text
- Use alphanumeric placeholders (`MATHPLACEHOLDERINLINE0END`) that won't
be interpreted as HTML tags
- Restore and render with KaTeX after markdown processing

**LaTeX document command removal:**
- Strip `\documentclass{...}`, `\usepackage{...}`, `\begin{document}`,
`\end{document}`
- Strip `\maketitle`, `\title{...}`, `\author{...}`, `\date{...}`
- Strip `\require{...}` (MathJax-specific, not KaTeX)
- Replace `tikzpicture` environments with `[diagram]` placeholder
- Strip `\label{...}` cross-reference commands

**LaTeX math environments:**
- Convert `\begin{align*}`, `\begin{equation}`, `\begin{gather}`, etc.
to display math blocks

**LaTeX text formatting:**
- `\emph{...}` and `\textit{...}` → `<em>...</em>`
- `\textbf{...}` → `<strong>...</strong>`
- `\texttt{...}` → `<code>...</code>`
- `\underline{...}` → `<u>...</u>`

**LaTeX environments styling:**
- `\begin{proof}...\end{proof}` → styled proof block with QED symbol
- `\begin{theorem}`, `\begin{lemma}`, etc. → styled theorem blocks

**Display math enhancements:**
- Wrapped in styled container with subtle gold border
- "LaTeX" label and copy button appear on hover
- Dark theme KaTeX color overrides for better readability
- Custom scrollbar for overflow

## Why It Works

The key insight is that markdown processing was escaping backslashes in
LaTeX before KaTeX could see them. By extracting all math expressions
into alphanumeric placeholders *before* markdown runs, then restoring
them *after*, the LaTeX content passes through to KaTeX unmodified.

Using purely alphanumeric placeholders like `MATHPLACEHOLDERINLINE0END`
instead of `<<MATH_INLINE_0>>` prevents markdown from interpreting them
as HTML tags and stripping them.

## Test Plan

### Manual Testing
- Hardware: Any machine with the dashboard
- What you did:
  - Ask model to "write a proof in latex"
  - Verify inline math like `$x \in S$` renders correctly
- Verify display math like `\begin{align*}...\end{align*}` renders as
block
  - Verify `\documentclass`, `\begin{document}` are stripped (not shown)
  - Verify `\emph{...}` converts to italics
  - Verify copy button works on display math blocks
- Test edge cases: `$5` (currency) stays as text, `\$50` (escaped)
becomes `$50`

Before:
<img width="799" height="637" alt="Screenshot 2026-01-19 at 11 51 22 AM"
src="https://github.com/user-attachments/assets/62a705b8-b3c2-47b8-afd0-5d0c1b240e44"
/>

After:
<img width="809" height="642" alt="Screenshot 2026-01-19 at 11 46 58 AM"
src="https://github.com/user-attachments/assets/4f35fa1d-333c-4285-bc68-58a50f8f148e"
/>


### Automated Testing
- Dashboard builds successfully with `npm run build`
- Existing functionality preserved

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:50:41 +00:00
rltakashige
ea0588429b Custom mlx layer composition (#1201)
## Motivation

With a single pipeline layer, PipelineFirstLayer gets composed with
PipelineLastLayer.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing


### Automated Testing
Made failing tests. Fixed them!
2026-01-19 12:36:25 +00:00
rltakashige
73b3f87e07 Set swa_idx and ga_idx for single layer (#1202)
## Motivation

Layer types does not contain either "sliding_attention" or
"full_attention" for pipeline parallel (single layer).

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Manually tested single layer of GPT OSS. Doesn't crash

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-19 12:31:11 +00:00
311 changed files with 9199 additions and 5458 deletions

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
import os
if "TOKENIZERS_PARALLELISM" not in os.environ: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,47 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import Protocol
from mflux.models.common.config.config import Config
class BeforeLoopCallback(Protocol):
def call_before_loop(
self,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
class InLoopCallback(Protocol):
def call_in_loop(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...
class AfterLoopCallback(Protocol):
def call_after_loop(
self, seed: int, prompt: str, latents: mx.array, config: Config
) -> None: ...
class InterruptCallback(Protocol):
def call_interrupt(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.callbacks.callback import (
AfterLoopCallback,
BeforeLoopCallback,
InLoopCallback,
InterruptCallback,
)
from mflux.callbacks.generation_context import GenerationContext
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class CallbackRegistry:
def __init__(self) -> None: ...
def register(self, callback) -> None: ...
def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
def in_loop_callbacks(self) -> list[InLoopCallback]: ...
def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
def interrupt_callbacks(self) -> list[InterruptCallback]: ...

View File

@@ -0,0 +1,29 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import TYPE_CHECKING
from mflux.callbacks.callback_registry import CallbackRegistry
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class GenerationContext:
def __init__(
self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
) -> None: ...
def before_loop(
self,
latents: mx.array,
*,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...
def after_loop(self, latents: mx.array) -> None: ...
def interruption(
self, t: int, latents: mx.array, time_steps: tqdm = ...
) -> None: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
import os
BATTERY_PERCENTAGE_STOP_LIMIT = ...
CONTROLNET_STRENGTH = ...
DEFAULT_DEV_FILL_GUIDANCE = ...
DEFAULT_DEPTH_GUIDANCE = ...
DIMENSION_STEP_PIXELS = ...
GUIDANCE_SCALE = ...
GUIDANCE_SCALE_KONTEXT = ...
IMAGE_STRENGTH = ...
MODEL_CHOICES = ...
MODEL_INFERENCE_STEPS = ...
QUANTIZE_CHOICES = ...
if os.environ.get("MFLUX_CACHE_DIR"):
MFLUX_CACHE_DIR = ...
else:
MFLUX_CACHE_DIR = ...
MFLUX_LORA_CACHE_DIR = ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
__all__ = ["Config", "ModelConfig"]

View File

@@ -0,0 +1,66 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import Any
from tqdm import tqdm
from mflux.models.common.config.model_config import ModelConfig
logger = ...
class Config:
def __init__(
self,
model_config: ModelConfig,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
depth_image_path: Path | str | None = ...,
redux_image_paths: list[Path | str] | None = ...,
redux_image_strengths: list[float] | None = ...,
masked_image_path: Path | str | None = ...,
controlnet_strength: float | None = ...,
scheduler: str = ...,
) -> None: ...
@property
def height(self) -> int: ...
@property
def width(self) -> int: ...
@width.setter
def width(self, value): # -> None:
...
@property
def image_seq_len(self) -> int: ...
@property
def guidance(self) -> float: ...
@property
def num_inference_steps(self) -> int: ...
@property
def precision(self) -> mx.Dtype: ...
@property
def num_train_steps(self) -> int: ...
@property
def image_path(self) -> Path | None: ...
@property
def image_strength(self) -> float | None: ...
@property
def depth_image_path(self) -> Path | None: ...
@property
def redux_image_paths(self) -> list[Path] | None: ...
@property
def redux_image_strengths(self) -> list[float] | None: ...
@property
def masked_image_path(self) -> Path | None: ...
@property
def init_time_step(self) -> int: ...
@property
def time_steps(self) -> tqdm: ...
@property
def controlnet_strength(self) -> float | None: ...
@property
def scheduler(self) -> Any: ...

View File

@@ -0,0 +1,86 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from functools import lru_cache
from typing import Literal
class ModelConfig:
precision: mx.Dtype = ...
def __init__(
self,
priority: int,
aliases: list[str],
model_name: str,
base_model: str | None,
controlnet_model: str | None,
custom_transformer_model: str | None,
num_train_steps: int | None,
max_sequence_length: int | None,
supports_guidance: bool | None,
requires_sigma_shift: bool | None,
transformer_overrides: dict | None = ...,
) -> None: ...
@staticmethod
@lru_cache
def dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_kontext() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_redux() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_depth() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_upscaler() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill_catvton() -> ModelConfig: ...
@staticmethod
@lru_cache
def krea_dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_4b() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_9b() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image_edit() -> ModelConfig: ...
@staticmethod
@lru_cache
def fibo() -> ModelConfig: ...
@staticmethod
@lru_cache
def z_image_turbo() -> ModelConfig: ...
@staticmethod
@lru_cache
def seedvr2_3b() -> ModelConfig: ...
def x_embedder_input_dim(self) -> int: ...
def is_canny(self) -> bool: ...
@staticmethod
def from_name(
model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
) -> ModelConfig: ...
AVAILABLE_MODELS = ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.z_image.latent_creator.z_image_latent_creator import (
ZImageLatentCreator,
)
if TYPE_CHECKING:
LatentCreatorType: TypeAlias = type[
FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
]
class Img2Img:
def __init__(
self,
vae: nn.Module,
latent_creator: LatentCreatorType,
sigmas: mx.array,
init_time_step: int,
image_path: str | Path | None,
tiling_config: TilingConfig | None = ...,
) -> None: ...
class LatentCreator:
@staticmethod
def create_for_txt2img_or_img2img(
seed: int, height: int, width: int, img2img: Img2Img
) -> mx.array: ...
@staticmethod
def encode_image(
vae: nn.Module,
image_path: str | Path,
height: int,
width: int,
tiling_config: TilingConfig | None = ...,
) -> mx.array: ...
@staticmethod
def add_noise_by_interpolation(
clean: mx.array, noise: mx.array, sigma: float
) -> mx.array: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
class FusedLoRALinear(nn.Module):
def __init__(
self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
class LoRALinear(nn.Module):
@staticmethod
def from_linear(
linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
): # -> LoRALinear:
...
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = ...,
scale: float = ...,
bias: bool = ...,
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
from collections.abc import Callable
from dataclasses import dataclass
from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
@dataclass
class PatternMatch:
source_pattern: str
target_path: str
matrix_name: str
transpose: bool
transform: Callable[[mx.array], mx.array] | None = ...
class LoRALoader:
@staticmethod
def load_and_apply_lora(
lora_mapping: list[LoRATarget],
transformer: nn.Module,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> tuple[list[str], list[float]]: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Protocol
@dataclass
class LoRATarget:
model_path: str
possible_up_patterns: List[str]
possible_down_patterns: List[str]
possible_alpha_patterns: List[str] = ...
up_transform: Callable[[mx.array], mx.array] | None = ...
down_transform: Callable[[mx.array], mx.array] | None = ...
class LoRAMapping(Protocol):
@staticmethod
def get_mapping() -> List[LoRATarget]: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
class LoRASaver:
@staticmethod
def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...

View File

@@ -0,0 +1,35 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class LoraTransforms:
@staticmethod
def split_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_down(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.resolution.config_resolution import ConfigResolution
from mflux.models.common.resolution.lora_resolution import LoraResolution
from mflux.models.common.resolution.path_resolution import PathResolution
from mflux.models.common.resolution.quantization_resolution import (
QuantizationResolution,
)
__all__ = [
"ConfigResolution",
"LoraResolution",
"PathResolution",
"QuantizationResolution",
]

View File

@@ -0,0 +1,39 @@
"""
This type stub file was generated by pyright.
"""
from enum import Enum
from typing import NamedTuple
class QuantizationAction(Enum):
NONE = ...
STORED = ...
REQUESTED = ...
class PathAction(Enum):
LOCAL = ...
HUGGINGFACE_CACHED = ...
HUGGINGFACE = ...
ERROR = ...
class LoraAction(Enum):
LOCAL = ...
REGISTRY = ...
HUGGINGFACE_COLLECTION_CACHED = ...
HUGGINGFACE_COLLECTION = ...
HUGGINGFACE_REPO_CACHED = ...
HUGGINGFACE_REPO = ...
ERROR = ...
class ConfigAction(Enum):
EXACT_MATCH = ...
EXPLICIT_BASE = ...
INFER_SUBSTRING = ...
ERROR = ...
class Rule(NamedTuple):
priority: int
name: str
check: str
action: QuantizationAction | PathAction | LoraAction | ConfigAction
...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.config.model_config import ModelConfig
if TYPE_CHECKING: ...
logger = ...
class ConfigResolution:
RULES = ...
@staticmethod
def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class LoraResolution:
RULES = ...
_registry: dict[str, Path] = ...
@staticmethod
def resolve(path: str) -> str: ...
@staticmethod
def resolve_paths(paths: list[str] | None) -> list[str]: ...
@staticmethod
def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
@staticmethod
def get_registry() -> dict[str, Path]: ...
@staticmethod
def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class PathResolution:
RULES = ...
@staticmethod
def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
logger = ...
class QuantizationResolution:
RULES = ...
@staticmethod
def resolve(
stored: int | None, requested: int | None
) -> tuple[int | None, str | None]: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
from .linear_scheduler import LinearScheduler
from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
__all__ = [
"LinearScheduler",
"FlowMatchEulerDiscreteScheduler",
"SeedVR2EulerScheduler",
]
class SchedulerModuleNotFound(ValueError): ...
class SchedulerClassNotFound(ValueError): ...
class InvalidSchedulerType(TypeError): ...
SCHEDULER_REGISTRY = ...
def register_contrib(scheduler_object, scheduler_name=...): # -> None:
...
def try_import_external_scheduler(
scheduler_object_path: str,
): # -> type[BaseScheduler]:
...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from abc import ABC, abstractmethod
class BaseScheduler(ABC):
@property
@abstractmethod
def sigmas(self) -> mx.array: ...
@abstractmethod
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class FlowMatchEulerDiscreteScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def set_image_seq_len(self, image_seq_len: int) -> None: ...
@staticmethod
def get_timesteps_and_sigmas(
image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
) -> tuple[mx.array, mx.array]: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class LinearScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class SeedVR2EulerScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def timesteps(self) -> mx.array: ...
@property
def sigmas(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.tokenizer.tokenizer import (
BaseTokenizer,
LanguageTokenizer,
Tokenizer,
VisionLanguageTokenizer,
)
from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
__all__ = [
"Tokenizer",
"BaseTokenizer",
"LanguageTokenizer",
"VisionLanguageTokenizer",
"TokenizerLoader",
"TokenizerOutput",
]

View File

@@ -0,0 +1,74 @@
"""
This type stub file was generated by pyright.
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
from PIL import Image
from transformers import PreTrainedTokenizer
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
@runtime_checkable
class Tokenizer(Protocol):
tokenizer: PreTrainedTokenizer
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class BaseTokenizer(ABC):
def __init__(
self, tokenizer: PreTrainedTokenizer, max_length: int = ...
) -> None: ...
@abstractmethod
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class LanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = ...,
padding: str = ...,
return_attention_mask: bool = ...,
template: str | None = ...,
use_chat_template: bool = ...,
chat_template_kwargs: dict | None = ...,
add_special_tokens: bool = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class VisionLanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
processor,
max_length: int = ...,
template: str | None = ...,
image_token: str = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING: ...
class TokenizerLoader:
@staticmethod
def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
@staticmethod
def load_all(
definitions: list[TokenizerDefinition],
model_path: str,
max_length_overrides: dict[str, int] | None = ...,
) -> dict[str, BaseTokenizer]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
"""
This type stub file was generated by pyright.
"""
@dataclass
class TokenizerOutput:
input_ids: mx.array
attention_mask: mx.array
pixel_values: mx.array | None = ...
image_grid_thw: mx.array | None = ...

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.common.vae.vae_tiler import VAETiler
__all__ = ["TilingConfig", "VAETiler"]

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TilingConfig:
vae_decode_tiles_per_dim: int | None = ...
vae_decode_overlap: int = ...
vae_encode_tiled: bool = ...
vae_encode_tile_size: int = ...
vae_encode_tile_overlap: int = ...

View File

@@ -0,0 +1,27 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Callable
class VAETiler:
@staticmethod
def encode_image_tiled(
*,
image: mx.array,
encode_fn: Callable[[mx.array], mx.array],
latent_channels: int,
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...
@staticmethod
def decode_image_tiled(
*,
latent: mx.array,
decode_fn: Callable[[mx.array], mx.array],
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
class VAEUtil:
@staticmethod
def encode(
vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...
@staticmethod
def decode(
vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
from mflux.models.common.weights.loading.weight_applier import WeightApplier
from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
from mflux.models.common.weights.loading.weight_loader import WeightLoader
from mflux.models.common.weights.saving.model_saver import ModelSaver
__all__ = [
"ComponentDefinition",
"LoadedWeights",
"MetaData",
"ModelSaver",
"WeightApplier",
"WeightLoader",
]

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass
class MetaData:
quantization_level: int | None = ...
mflux_version: str | None = ...
@dataclass
class LoadedWeights:
components: dict[str, dict]
meta_data: MetaData
def __getattr__(self, name: str) -> dict | None: ...
def num_transformer_blocks(self, component_name: str = ...) -> int: ...
def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...

View File

@@ -0,0 +1,30 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
class WeightApplier:
@staticmethod
def apply_and_quantize_single(
weights: LoadedWeights,
model: nn.Module,
component: ComponentDefinition,
quantize_arg: int | None,
quantization_predicate=...,
) -> int | None: ...
@staticmethod
def apply_and_quantize(
weights: LoadedWeights,
models: dict[str, nn.Module],
quantize_arg: int | None,
weight_definition: WeightDefinitionType,
) -> int | None: ...

View File

@@ -0,0 +1,73 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, TYPE_CHECKING, TypeAlias
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
DepthProWeightDefinition,
)
from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
FIBOVLMWeightDefinition,
)
from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
SeedVR2WeightDefinition,
)
from mflux.models.z_image.weights.z_image_weight_definition import (
ZImageWeightDefinition,
)
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING:
WeightDefinitionType: TypeAlias = type[
FluxWeightDefinition
| FIBOWeightDefinition
| FIBOVLMWeightDefinition
| QwenWeightDefinition
| ZImageWeightDefinition
| SeedVR2WeightDefinition
| DepthProWeightDefinition
]
@dataclass
class ComponentDefinition:
name: str
hf_subdir: str
mapping_getter: Callable[[], List[WeightTarget]] | None = ...
model_attr: str | None = ...
num_blocks: int | None = ...
num_layers: int | None = ...
loading_mode: str = ...
precision: mx.Dtype | None = ...
skip_quantization: bool = ...
bulk_transform: Callable[[mx.array], mx.array] | None = ...
weight_subkey: str | None = ...
download_url: str | None = ...
weight_prefix_filters: List[str] | None = ...
weight_files: List[str] | None = ...
@dataclass
class TokenizerDefinition:
name: str
hf_subdir: str
tokenizer_class: str = ...
fallback_subdirs: List[str] | None = ...
download_patterns: List[str] | None = ...
encoder_class: type[BaseTokenizer] | None = ...
max_length: int = ...
padding: str = ...
template: str | None = ...
use_chat_template: bool = ...
chat_template_kwargs: dict | None = ...
add_special_tokens: bool = ...
processor_class: type | None = ...
image_token: str = ...
chat_template: str | None = ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
logger = ...
class WeightLoader:
@staticmethod
def load_single(
component: ComponentDefinition, repo_id: str, file_pattern: str = ...
) -> LoadedWeights: ...
@staticmethod
def load(
weight_definition: WeightDefinitionType, model_path: str | None = ...
) -> LoadedWeights: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Dict, List, Optional
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
class WeightMapper:
@staticmethod
def apply_mapping(
hf_weights: Dict[str, mx.array],
mapping: List[WeightTarget],
num_blocks: Optional[int] = ...,
num_layers: Optional[int] = ...,
) -> Dict: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol
"""
This type stub file was generated by pyright.
"""
@dataclass
class WeightTarget:
to_pattern: str
from_pattern: List[str]
transform: Optional[Callable[[mx.array], mx.array]] = ...
required: bool = ...
max_blocks: Optional[int] = ...
class WeightMapping(Protocol):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class WeightTransforms:
@staticmethod
def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import Any, TYPE_CHECKING
from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
if TYPE_CHECKING: ...
class ModelSaver:
@staticmethod
def save_model(
model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
) -> None: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
class DepthProInitializer:
@staticmethod
def init(model: DepthProModel, quantize: int | None = ...) -> None: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FeatureFusionBlock2d(nn.Module):
def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MultiresConvDecoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self,
x0_latent: mx.array,
x1_latent: mx.array,
x0_features: mx.array,
x1_features: mx.array,
x_global_features: mx.array,
) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, num_features: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
@dataclass
class DepthResult:
depth_image: Image.Image
depth_array: mx.array
min_depth: float
max_depth: float
...
class DepthPro:
def __init__(self, quantize: int | None = ...) -> None: ...
def create_depth_map(self, image_path: str | Path) -> DepthResult: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProModel(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProUtil:
@staticmethod
def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
@staticmethod
def interpolate(x: mx.array, size=..., scale_factor=...): # -> array:
...
@staticmethod
def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class Attention(nn.Module):
def __init__(
self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DinoVisionTransformer(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class LayerScale(nn.Module):
def __init__(self, dims: int, init_values: float = ...) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class PatchEmbed(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class TransformerBlock(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class UpSampleBlock(nn.Module):
def __init__(
self,
dim_in: int = ...,
dim_int: int = ...,
dim_out: int = ...,
upsample_layers: int = ...,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FOVHead(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class DepthProWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class DepthProWeightMapping(WeightMapping):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class FiboLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class FIBOWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOWeightMapping(WeightMapping):
@staticmethod
def get_transformer_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_text_encoder_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_vae_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor
class Qwen2VLImageProcessor(QwenImageProcessor):
def __init__(self) -> None: ...

View File

@@ -0,0 +1,28 @@
"""
This type stub file was generated by pyright.
"""
from typing import Optional, Union
from PIL import Image
class Qwen2VLProcessor:
def __init__(self, tokenizer) -> None: ...
def apply_chat_template(
self,
messages,
tokenize: bool = ...,
add_generation_prompt: bool = ...,
return_tensors: Optional[str] = ...,
return_dict: bool = ...,
**kwargs,
): # -> dict[Any, Any]:
...
def __call__(
self,
text: Optional[Union[str, list[str]]] = ...,
images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
padding: bool = ...,
return_tensors: Optional[str] = ...,
**kwargs,
): # -> dict[Any, Any]:
...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
QWEN2VL_CHAT_TEMPLATE = ...
class FIBOVLMWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOVLMWeightMapping(WeightMapping):
@staticmethod
def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
@staticmethod
def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,53 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config import ModelConfig
class FluxInitializer:
@staticmethod
def init(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
custom_transformer=...,
) -> None: ...
@staticmethod
def init_depth(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_redux(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_controlnet(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_concept(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,19 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
"""
This type stub file was generated by pyright.
"""
class FluxLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(
latents: mx.array, height: int, width: int, num_channels_latents: int = ...
) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEmbeddings(nn.Module):
def __init__(self, dims: int) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class CLIPEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEncoderLayer(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPMLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def quick_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPSdpaAttention(nn.Module):
head_dimension = ...
batch_size = ...
num_heads = ...
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...
@staticmethod
def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPTextModel(nn.Module):
def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
@staticmethod
def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class EncoderCLIP(nn.Module):
def __init__(self, num_encoder_layers: int) -> None: ...
def __call__(
self, tokens: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -0,0 +1,25 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.common.tokenizer import Tokenizer
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
"""
This type stub file was generated by pyright.
"""
class PromptEncoder:
@staticmethod
def encode_prompt(
prompt: str,
prompt_cache: dict[str, tuple[mx.array, mx.array]],
t5_tokenizer: Tokenizer,
clip_tokenizer: Tokenizer,
t5_text_encoder: T5Encoder,
clip_text_encoder: CLIPEncoder,
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Attention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Block(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5DenseReluDense(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def new_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class T5Encoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array): ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5FeedForward(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5LayerNorm(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5SelfAttention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def shape(states): # -> array:
...
@staticmethod
def un_shape(states): # -> array:
...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...
def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZero(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZeroSingle(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,41 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AttentionUtils:
@staticmethod
def process_qkv(
hidden_states: mx.array,
to_q: nn.Linear,
to_k: nn.Linear,
to_v: nn.Linear,
norm_q: nn.RMSNorm,
norm_k: nn.RMSNorm,
num_heads: int,
head_dim: int,
) -> tuple[mx.array, mx.array, mx.array]: ...
@staticmethod
def compute_attention(
query: mx.array,
key: mx.array,
value: mx.array,
batch_size: int,
num_heads: int,
head_dim: int,
mask: mx.array | None = ...,
) -> mx.array: ...
@staticmethod
def convert_key_padding_mask_to_additive_mask(
mask: mx.array | None, joint_seq_len: int, txt_seq_len: int
) -> mx.array | None: ...
@staticmethod
def apply_rope(
xq: mx.array, xk: mx.array, freqs_cis: mx.array
) -> tuple[mx.array, mx.array]: ...
@staticmethod
def apply_rope_bshd(
xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array
) -> tuple[mx.array, mx.array]: ...

Some files were not shown because too many files have changed in this diff Show More