Compare commits


2 Commits

Author                  SHA1        Message                                      Date
Alex Cheema             55b67e2be2  Merge branch 'main' into releases/v1.0.64    2026-01-23 01:40:20 +00:00
Ryuichi Leo Takashige   30cfad9b68  Use custom fork                              2026-01-22 22:19:35 +00:00
17 changed files with 441 additions and 765 deletions

View File

@@ -2,14 +2,11 @@
This type stub file was generated by pyright.
"""
from typing import Any, Dict, List, Optional, Protocol, Literal, Self
import mlx.nn as nn
import mlx.core as mx
from typing import Any, Dict, List, Literal, Optional, Protocol, Self
from mlx.core import array
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class Cache(Protocol):
keys: mx.array
@@ -35,7 +32,6 @@ def make_prompt_cache(
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
...
def save_prompt_cache(
file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
@@ -49,7 +45,6 @@ def save_prompt_cache(
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
...
def load_prompt_cache(file_name: str, return_metadata=...) -> array:
"""
@@ -64,15 +59,13 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
...
def can_trim_prompt_cache(cache: List[Any]) -> bool:
def can_trim_prompt_cache(cache: List[Cache]) -> bool:
"""
Check if model's cache can be trimmed.
"""
...
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
"""
Trim the model's cache by the given number of tokens.
@@ -86,7 +79,6 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
Returns:
(int): The number of tokens that were trimmed.
"""
...
def create_attention_mask(
N: int, offset: int, return_array: bool, window_size: Optional[int]
@@ -115,125 +107,164 @@ class ConcatenateKVCache(_BaseCache):
KVCache with a larger step size before using this cache.
"""
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
@property
def state(self): ...
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class QuantizedKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> Any:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class KVCache(_BaseCache):
step = ...
offset: int
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(self) -> tuple[array, array]: ...
def state(
self,
) -> tuple[array, array]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): ...
def trim(self, n): ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(self, *args, **kwargs): ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
class RotatingKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, max_size, keep=...) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): ...
): # -> array | Literal['causal'] | None:
...
class ArraysCache(_BaseCache):
def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
def __setitem__(self, idx, value): ...
def __setitem__(self, idx, value): # -> None:
...
def __getitem__(self, idx): ...
@property
def state(self): ...
def state(self): # -> list[Any | array] | list[array]:
...
@state.setter
def state(self, v): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
def make_mask(self, N: int): ...
def make_mask(self, N: int): # -> array | None:
...
class MambaCache(ArraysCache):
def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...
class ChunkedKVCache(KVCache):
def __init__(self, chunk_size) -> None: ...
def maybe_trim_front(self): ...
def update_and_fetch(self, keys, values): ...
def trim(self, n): ...
def maybe_trim_front(self): # -> None:
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def meta_state(self, v): # -> None:
...
class CacheList(_BaseCache):
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): ...
def is_trimmable(self): # -> bool:
...
def trim(self, n): ...
@property
def state(self): ...
def state(self): # -> list[Any]:
...
@state.setter
def state(self, v): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
class BatchKVCache(_BaseCache):
step = ...
@@ -256,56 +287,71 @@ class BatchKVCache(_BaseCache):
And ``left_padding`` specifies the amount of padding for each.
In this case, ``left_padding = [1, 3, 0]``.
"""
...
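The padding convention above can be made concrete with a short illustration (not part of the stub; the prompt lengths are assumed, chosen only so the result matches the docstring's ``left_padding = [1, 3, 0]``):
prompt_lengths = [4, 2, 5]  # assumed lengths; the docstring only shows the resulting padding
left_padding = [max(prompt_lengths) - n for n in prompt_lengths]
print(left_padding)  # [1, 3, 0] -> each prompt is left-padded up to the longest prompt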
def update_and_fetch(self, keys, values): ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): ...
def filter(self, batch_indices):
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int | float:
...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
class BatchRotatingKVCache(_BaseCache):
step = ...
def __init__(self, max_size, left_padding: List[int]) -> None: ...
def update_and_fetch(self, keys, values): ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
@property
def state(self): ...
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
@state.setter
def state(self, v): ...
def state(self, v): # -> None:
...
@property
def meta_state(self): ...
def meta_state(self): # -> tuple[str, ...]:
...
@meta_state.setter
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): ...
def filter(self, batch_indices):
): # -> array:
...
def filter(self, batch_indices): # -> None:
"""
In-place filter to keep just the given indices in the cache.
"""
...
def extend(self, other):
def extend(self, other): # -> None:
"""
In-place extend this cache with the other cache.
"""
...
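As a hedged sketch of how the prompt-cache helpers stubbed above are typically combined (not part of this diff; the model id is an assumption, and a model/tokenizer loaded with ``mlx_lm.load`` is presumed):
from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # assumed model id
cache = make_prompt_cache(model, max_kv_size=4096)  # RotatingKVCache when max_kv_size is set
# ... run prefill/generation against `cache` ...
save_prompt_cache("prompt_cache.safetensors", cache, {"note": "example"})
cache = load_prompt_cache("prompt_cache.safetensors")
if can_trim_prompt_cache(cache):
    trim_prompt_cache(cache, 16)  # drop the last 16 cached tokens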

View File

@@ -8,10 +8,6 @@ from typing import Any
from transformers import PreTrainedTokenizerFast
"""
This type stub file was generated by pyright.
"""
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -49,7 +45,6 @@ class StreamingDetokenizer:
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
...
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -59,11 +54,15 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@property
def text(self): ...
def text(self): # -> str:
...
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
@@ -72,9 +71,12 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
underscore which results in linear complexity.
"""
def __init__(self, tokenizer, trim_space=...) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
@@ -86,13 +88,15 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
_byte_decoder = ...
_space_matches = ...
def __init__(self, tokenizer) -> None: ...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@classmethod
def make_byte_decoder(cls):
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
@classmethod
def make_byte_decoder(cls): # -> None:
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.
@@ -153,10 +157,13 @@ class TokenizerWrapper:
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
def __init__(self, *args, **kwargs) -> None: ...
def encode(self, text, **kwargs): ...
def encode(self, text, **kwargs): # -> list[int]:
...
def encode_batch(self, texts, **kwargs): ...
def decode(self, *args, **kwargs): ...
def batch_decode(self, *args, **kwargs): ...
def decode(self, *args, **kwargs): # -> str:
...
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load(
model_path: Path,
@@ -169,7 +176,6 @@ def load(
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
...
# Alias for backward compatibility
load_tokenizer = load
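A hedged sketch of the streaming-detokenizer interface described above (illustrative only; the local model path is an assumption, and the encoded string stands in for tokens produced by generation):
from pathlib import Path
from mlx_lm.tokenizer_utils import load_tokenizer

tokenizer = load_tokenizer(Path("/path/to/local/model"))  # local path -> fast streaming tokenizer
token_ids = tokenizer.encode("hello world")  # stand-in for generated token ids
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for token in token_ids:
    detokenizer.add_token(token)
    print(detokenizer.last_segment, end="", flush=True)  # readable text since the last access
detokenizer.finalize()
print(detokenizer.last_segment, end="", flush=True)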

View File

@@ -3,45 +3,6 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -81,12 +42,11 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -89,10 +89,7 @@
const isImageModel = $derived(() => {
if (!currentModel) return false;
return (
modelSupportsTextToImage(currentModel) ||
modelSupportsImageEditing(currentModel)
);
return modelSupportsTextToImage(currentModel);
});
const isEditOnlyWithoutImage = $derived(
@@ -649,23 +646,6 @@
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg

View File

@@ -110,36 +110,6 @@
setImageGenerationParams({ negativePrompt: value || null });
}
function handleNumImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ numImages: 1 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 1) {
setImageGenerationParams({ numImages: num });
}
}
}
function handleStreamChange(enabled: boolean) {
setImageGenerationParams({ stream: enabled });
}
function handlePartialImagesChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ partialImages: 0 });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 0) {
setImageGenerationParams({ partialImages: num });
}
}
}
function clearSteps() {
setImageGenerationParams({ numInferenceSteps: null });
}
@@ -164,92 +134,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -357,59 +325,6 @@
</div>
</div>
<!-- Number of Images (not in edit mode) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>IMAGES:</span
>
<input
type="number"
min="1"
value={params.numImages}
oninput={handleNumImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Stream toggle -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>STREAM:</span
>
<button
type="button"
onclick={() => handleStreamChange(!params.stream)}
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
? 'bg-exo-yellow'
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
>
<div
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
? 'right-0.5 bg-exo-black'
: 'left-0.5 bg-exo-light-gray'}"
></div>
</button>
</div>
<!-- Partial Images (only when streaming) -->
{#if params.stream}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>PARTIALS:</span
>
<input
type="number"
min="0"
value={params.partialImages}
oninput={handlePartialImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">

View File

@@ -238,10 +238,6 @@ export interface ImageGenerationParams {
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
// Streaming params
stream: boolean;
partialImages: number;
// Advanced params
seed: number | null;
numInferenceSteps: number | null;
@@ -261,9 +257,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,
stream: true,
partialImages: 3,
seed: null,
numInferenceSteps: null,
guidance: null,
@@ -1816,13 +1809,12 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
prompt,
n: params.numImages,
quality: params.quality,
size: params.size,
output_format: params.outputFormat,
response_format: "b64_json",
stream: params.stream,
partial_images: params.partialImages,
stream: true,
partial_images: 3,
};
if (hasAdvancedParams) {
@@ -1886,74 +1878,31 @@ class AppStore {
if (imageData && idx !== -1) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
const numImages = params.numImages;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].content = progressText;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
this.messages[idx].attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
this.messages[idx].attachments = [
...finals,
partialAttachment,
];
}
this.messages[idx].content =
`Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First final image - replace any partial preview
this.messages[idx].attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
this.messages[idx].attachments = [
...previousFinals,
newAttachment,
];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
this.messages[idx].content =
`Generating image ${imageIndex + 2}/${numImages}...`;
} else {
this.messages[idx].content = "";
}
// Final image
this.messages[idx].content = "";
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
}
}
} catch {
@@ -2034,8 +1983,8 @@ class AppStore {
formData.append("size", params.size);
formData.append("output_format", params.outputFormat);
formData.append("response_format", "b64_json");
formData.append("stream", params.stream ? "1" : "0");
formData.append("partial_images", params.partialImages.toString());
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
formData.append("partial_images", "3");
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
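For reference, a hedged Python equivalent of the form request built above (field names follow this diff; the endpoint URL and port are assumptions, the model id and other fields are omitted for brevity, and "stream" is sent as "1" because the server parses form booleans from strings):
import requests

with open("input.png", "rb") as f:
    response = requests.post(
        "http://localhost:52415/v1/images/edits",  # assumed endpoint; adjust host/port/path
        files={"image": f},
        data={
            "prompt": "make the sky purple",
            "size": "1024x1024",
            "response_format": "b64_json",
            "input_fidelity": "low",
            "stream": "1",          # parsed server-side with: value.lower() in ("true", "1", "yes")
            "partial_images": "3",
        },
        stream=True,               # keep the HTTP connection open for streamed partials
    )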

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer

View File

@@ -1,5 +1,4 @@
import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
@@ -34,7 +33,6 @@ from exo.shared.models.model_cards import (
ModelId,
)
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
@@ -837,7 +835,6 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),
@@ -1027,9 +1024,6 @@ class API:
stream: bool,
partial_images: int,
bench: bool,
quality: Literal["high", "medium", "low"],
output_format: Literal["png", "jpeg", "webp"],
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
@@ -1058,9 +1052,6 @@ class API:
stream=stream,
partial_images=partial_images,
bench=bench,
quality=quality,
output_format=output_format,
advanced_params=advanced_params,
),
)
@@ -1095,22 +1086,12 @@ class API:
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1122,9 +1103,6 @@ class API:
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
if stream_bool and partial_images_int > 0:
@@ -1155,18 +1133,8 @@ class API:
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1178,9 +1146,6 @@ class API:
stream=False,
partial_images=0,
bench=True,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
return await self._collect_image_generation_with_stats(
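A hedged sketch of handling the streamed events produced above: after this change each event carries "type", "partial_index", "total_partials", and "format", with the image delivered as base64 ("b64_json" is assumed to be the payload key, per the comment in this diff; the surrounding I/O is illustrative only):
import base64
import json

def handle_image_event(raw_event: str, out_dir: str = ".") -> None:
    event = json.loads(raw_event)
    image_bytes = base64.b64decode(event["b64_json"])  # payload key assumed
    fmt = event.get("format", "png")
    if event["type"] == "partial":
        name = f"partial-{event['partial_index'] + 1}-of-{event['total_partials']}.{fmt}"
    else:  # "final"
        name = f"generated-image.{fmt}"
    with open(f"{out_dir}/{name}", "wb") as fh:
        fh.write(image_bytes)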

View File

@@ -1,11 +0,0 @@
"""Shared types for MLX-related functionality."""
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]

View File

@@ -30,7 +30,6 @@ class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -45,7 +44,6 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]

View File

@@ -75,20 +75,19 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
base_seed = advanced_params.seed
seed = advanced_params.seed
else:
base_seed = random.randint(0, 2**32 - 1)
seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
num_images = task.n or 1
generation_start_time: float = 0.0
@@ -96,11 +95,7 @@ def generate_image(
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
)
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
@@ -110,81 +105,72 @@ def generate_image(
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
)
else:
image = result
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images
seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)
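As rewritten, the generator yields PartialImageResponse objects for intermediate previews and a single ImageGenerationResponse at the end; a hedged consumption sketch (the call signature, import path, and ``model``/``task`` setup are assumptions based on this diff):
from exo.shared.types.worker.runner_response import (  # module path assumed from this diff
    ImageGenerationResponse,
    PartialImageResponse,
)

for response in generate_image(model=model, task=task):  # model/task provided by the runner
    if isinstance(response, PartialImageResponse):
        print(f"partial {response.partial_index + 1}/{response.total_partials}")
    else:  # ImageGenerationResponse
        with open(f"output.{response.format}", "wb") as fh:
            fh.write(response.image_data)
        if response.stats is not None:
            print(f"{response.stats.seconds_per_step:.2f} s/step")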

View File

@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call

View File

@@ -1,53 +1,39 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable
import mlx.core as mx
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self.caches: list[list[_BaseCache]] = []
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
):
tokenized_prompt = encode_prompt(tokenizer, prompt)
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
logger.info(f"KV cache saved: {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> tuple[KVCacheType, mx.array]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
This method finds the best matching cached prefix and returns:
- A copy of the cache trimmed to the prefix length
- The remaining tokens that need to be prefilled before generation
The caller is responsible for prefilling the remaining tokens.
Returns:
Tuple of (cache, remaining_tokens) where remaining_tokens are the
tokens that still need to be prefilled/processed.
"""
tokenized_prompt = encode_prompt(tokenizer, prompt)
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
@@ -56,75 +42,63 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:]
return self.caches[i]
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)
# Return remaining tokens for caller to prefill
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
)
# Return all tokens for caller to prefill
return prompt_cache, tokenized_prompt
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
return prompt_cache
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
if n == 0:
return 0
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
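A worked example of the cumprod prefix trick above (illustrative, not part of the change):
import mlx.core as mx

prompt = mx.array([1, 2, 3, 9])
cached = mx.array([1, 2, 3, 4])
equal = (prompt == cached).astype(mx.int32)  # [1, 1, 1, 0]
prefix_mask = mx.cumprod(equal)              # stays 1 until the first mismatch: [1, 1, 1, 0]
print(int(mx.sum(prefix_mask).item()))       # 3 -> three tokens of shared prefix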
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
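A hedged usage sketch of the rewritten prefix cache (``model``, ``tokenizer``, ``sampler``, and ``prompt`` are assumed to come from the runner; generation itself is elided):
kv_prefix_cache = KVPrefixCache()

# First request: no prefix match, so get_kv_cache() builds a fresh cache and prefills it.
cache = kv_prefix_cache.get_kv_cache(model, tokenizer, sampler, prompt)
# ... run stream_generate(..., prompt_cache=cache) ...
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)

# A follow-up prompt sharing a prefix reuses the saved cache and only prefills the new suffix.
cache = kv_prefix_cache.get_kv_cache(model, tokenizer, sampler, prompt + " And then?")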

View File

@@ -1,12 +1,12 @@
import time
from typing import Callable, Generator, cast, get_args
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,13 +14,11 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -32,59 +30,19 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cache, 1)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def warmup_inference(
@@ -162,7 +120,6 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -174,16 +131,7 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -196,19 +144,11 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[-1:], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -218,13 +158,12 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -246,22 +185,6 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
# Save cache for future prefix matching (clear first to keep only the last one)
if kv_prefix_cache is not None:
kv_prefix_cache.clear()
full_prompt = prompt + "".join(generated_text_parts)
kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
break
# TODO: Do we want an mx_barrier?
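After this change the generator hands the full prompt string straight to mlx_lm's stream_generate with a fresh per-request cache; a minimal hedged sketch of that path (``model``, ``tokenizer``, and ``prompt`` are assumed from the runner, and the sampling values are illustrative):
from mlx_lm.generate import stream_generate
from mlx_lm.sample_utils import make_sampler

sampler = make_sampler(temp=0.7, top_p=1.0)
for out in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,          # templated prompt string
    max_tokens=512,
    sampler=sampler,
):
    print(out.text, end="", flush=True)
    if out.finish_reason is not None:
        print(f"\n[{out.generation_tps:.1f} tok/s]")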

View File

@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)

View File

@@ -70,7 +70,6 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -104,7 +103,6 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -173,9 +171,6 @@ def main(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
kv_prefix_cache = KVPrefixCache()
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -243,7 +238,6 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# GPT-OSS specific parsing to match other model formats.
@@ -618,7 +612,7 @@ def _process_image_response(
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,

uv.lock generated
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -413,7 +413,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -458,16 +458,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -1004,8 +994,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1032,18 +1022,12 @@ wheels = [
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
resolution-markers = [
"sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1056,6 +1040,14 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1086,7 +1078,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1094,16 +1086,6 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"
@@ -2227,6 +2209,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"