Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-23 13:29:29 -05:00)

Compare commits: v1.0.64...fix-kv-pre (6 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 7744420341 | |
| | b777c6f505 | |
| | 812a9f232e | |
| | f255345a1a | |
| | a1939c89f2 | |
| | cb9c9ee55c | |
@@ -2,11 +2,14 @@
This type stub file was generated by pyright.
"""

from typing import Any, Dict, List, Optional, Protocol, Literal, Self

import mlx.nn as nn
from mlx.core import array
import mlx.core as mx
from typing import Any, Dict, List, Literal, Optional, Protocol, Self
from mlx.core import array

"""
This type stub file was generated by pyright.
"""

class Cache(Protocol):
keys: mx.array
@@ -32,6 +35,7 @@ def make_prompt_cache(
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
...

def save_prompt_cache(
file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...
@@ -45,6 +49,7 @@ def save_prompt_cache(
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
...

def load_prompt_cache(file_name: str, return_metadata=...) -> array:
"""
@@ -59,13 +64,15 @@ def load_prompt_cache(file_name: str, return_metadata=...) -> array:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
...

def can_trim_prompt_cache(cache: List[Cache]) -> bool:
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
...

def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
"""
Trim the model's cache by the given number of tokens.

@@ -79,6 +86,7 @@ def trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:
Returns:
(int): The number of tokens that were trimmed.
"""
...

def create_attention_mask(
N: int, offset: int, return_array: bool, window_size: Optional[int]
@@ -107,164 +115,125 @@ class ConcatenateKVCache(_BaseCache):
KVCache with a larger step size before using this cache.
"""
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
...
def update_and_fetch(self, keys, values): ...
@property
def state(self): # -> tuple[Any | array | None, Any | array | None]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...

class QuantizedKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...
def update_and_fetch(self, keys, values): # -> Any:
...
def update_and_fetch(self, keys, values): ...
@property
def state(
self,
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def state(self, v): ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
def meta_state(self): ...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, *args, **kwargs): ...

class KVCache(_BaseCache):
step = ...
offset: int
def __init__(self) -> None: ...
def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
def update_and_fetch(self, keys, values): ...
@property
def state(
self,
) -> tuple[array, array]: ...
def state(self) -> tuple[array, array]: ...
@state.setter
def state(self, v) -> None: ...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int:
...
def is_trimmable(self): ...
def trim(self, n): ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
...
def make_mask(self, *args, **kwargs): ...

class RotatingKVCache(_BaseCache):
step = ...
offset: int
def __init__(self, max_size, keep=...) -> None: ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
def update_and_fetch(self, keys, values): ...
@property
def state(
self,
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def state(self, v): ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
def meta_state(self): ...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): # -> array | Literal['causal'] | None:
...
): ...

class ArraysCache(_BaseCache):
def __init__(self, size, left_padding: Optional[List[int]] = ...) -> None: ...
def __setitem__(self, idx, value): # -> None:
...
def __setitem__(self, idx, value): ...
def __getitem__(self, idx): ...
@property
def state(self): # -> list[Any | array] | list[array]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
def state(self, v): ...
def filter(self, batch_indices):
"""
In-place filter to keep just the given indices in the cache.
"""
...

def extend(self, other): # -> None:
def extend(self, other):
"""
In-place extend this cache with the other cache.
"""

def make_mask(self, N: int): # -> array | None:
...

def make_mask(self, N: int): ...

class MambaCache(ArraysCache):
def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...

class ChunkedKVCache(KVCache):
def __init__(self, chunk_size) -> None: ...
def maybe_trim_front(self): # -> None:
...
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
...
def trim(self, n): # -> int:
...
def maybe_trim_front(self): ...
def update_and_fetch(self, keys, values): ...
def trim(self, n): ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
def meta_state(self): ...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def meta_state(self, v): ...

class CacheList(_BaseCache):
def __init__(self, *caches) -> None: ...
def __getitem__(self, idx): ...
def is_trimmable(self): # -> bool:
...
def is_trimmable(self): ...
def trim(self, n): ...
@property
def state(self): # -> list[Any]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def filter(self, batch_indices): # -> None:
def state(self, v): ...
def filter(self, batch_indices):
"""
In-place filter to keep just the given indices in the cache.
"""
...

def extend(self, other): # -> None:
def extend(self, other):
"""
In-place extend this cache with the other cache.
"""
...

class BatchKVCache(_BaseCache):
step = ...
@@ -287,71 +256,56 @@ class BatchKVCache(_BaseCache):
And ``left_padding`` specifies the amount of padding for each.
In this case, ``left_padding = [1, 3, 0]``.
"""
...

def update_and_fetch(self, keys, values): # -> tuple[array | Any, array | Any]:
...
def update_and_fetch(self, keys, values): ...
@property
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def is_trimmable(self): # -> Literal[True]:
...
def trim(self, n): # -> int | float:
...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): # -> array:
...
def filter(self, batch_indices): # -> None:
def state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def make_mask(self, N: int, return_array: bool = ..., **kwargs): ...
def filter(self, batch_indices):
"""
In-place filter to keep just the given indices in the cache.
"""
...

def extend(self, other): # -> None:
def extend(self, other):
"""
In-place extend this cache with the other cache.
"""
...

class BatchRotatingKVCache(_BaseCache):
step = ...
def __init__(self, max_size, left_padding: List[int]) -> None: ...
def update_and_fetch(
self, keys, values
): # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:
...
def update_and_fetch(self, keys, values): ...
@property
def state(
self,
): # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:
...
def state(self): ...
@state.setter
def state(self, v): # -> None:
...
def state(self, v): ...
@property
def meta_state(self): # -> tuple[str, ...]:
...
def meta_state(self): ...
@meta_state.setter
def meta_state(self, v): # -> None:
...
def is_trimmable(self): # -> bool:
...
def trim(self, n): # -> int:
...
def meta_state(self, v): ...
def is_trimmable(self): ...
def trim(self, n): ...
def to_quantized(
self, group_size: int = ..., bits: int = ...
) -> QuantizedKVCache: ...
def make_mask(
self, N: int, window_size: Optional[int] = ..., return_array: bool = ...
): # -> array:
...
def filter(self, batch_indices): # -> None:
): ...
def filter(self, batch_indices):
"""
In-place filter to keep just the given indices in the cache.
"""
...

def extend(self, other): # -> None:
def extend(self, other):
"""
In-place extend this cache with the other cache.
"""
...
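(Aside, not part of the diff.) The regenerated stubs above track mlx_lm's prompt-cache helpers; note that `can_trim_prompt_cache` and `trim_prompt_cache` now take `List[Any]` and that `trim_prompt_cache` returns the number of tokens actually trimmed. A minimal, hedged usage sketch, assuming mlx_lm is installed and that `model` and `tokens` are an already-loaded model and token array (placeholders):

```python
# Minimal sketch, not part of the diff: exercising the mlx_lm prompt-cache
# helpers whose stubs are regenerated above. `model` and `tokens` are assumed
# to be an already-loaded mlx_lm model and a token array.
from mlx_lm.models.cache import (
    make_prompt_cache,
    can_trim_prompt_cache,
    trim_prompt_cache,
    save_prompt_cache,
    load_prompt_cache,
)

cache = make_prompt_cache(model)  # one cache entry per transformer layer
# ... run the model over `tokens` with this cache to fill it ...

if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, 4)  # returns int: tokens actually trimmed

save_prompt_cache("prompt_cache.safetensors", cache, {"note": "demo"})
cache2 = load_prompt_cache("prompt_cache.safetensors")
```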
@@ -8,6 +8,10 @@ from typing import Any

from transformers import PreTrainedTokenizerFast

"""
This type stub file was generated by pyright.
"""

class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.

@@ -45,6 +49,7 @@ class StreamingDetokenizer:
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
...

class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
@@ -54,15 +59,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@property
def text(self): # -> str:
...
def text(self): ...

class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
@@ -71,12 +72,9 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
underscore which results in linear complexity.
"""
def __init__(self, tokenizer, trim_space=...) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...

class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
@@ -88,15 +86,13 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
_byte_decoder = ...
_space_matches = ...
def __init__(self, tokenizer) -> None: ...
def reset(self): # -> None:
...
def add_token(self, token): # -> None:
...
def finalize(self): # -> None:
...
def reset(self): ...
def add_token(self, token): ...
def finalize(self): ...
@classmethod
def make_byte_decoder(cls): # -> None:
def make_byte_decoder(cls):
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
...

class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.

@@ -157,13 +153,10 @@ class TokenizerWrapper:
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
def __init__(self, *args, **kwargs) -> None: ...
def encode(self, text, **kwargs): # -> list[int]:
...
def encode(self, text, **kwargs): ...
def encode_batch(self, texts, **kwargs): ...
def decode(self, *args, **kwargs): # -> str:
...
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def decode(self, *args, **kwargs): ...
def batch_decode(self, *args, **kwargs): ...

def load(
model_path: Path,
@@ -176,6 +169,7 @@ def load(
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
...

# Alias for backward compatibility
load_tokenizer = load
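(Aside, not part of the diff.) The `StreamingDetokenizer` interface stubbed above (reset / add_token / finalize plus the `last_segment` property) is how text is emitted one token at a time. A minimal sketch, assuming `tokenizer` is an mlx_lm `TokenizerWrapper` and `token_ids` is a list of generated ids (both placeholders):

```python
# Minimal sketch, not part of the diff: streaming detokenization with the
# interface stubbed above. `tokenizer` is an mlx_lm TokenizerWrapper and
# `token_ids` a list of generated token ids (both assumed).
detok = tokenizer.detokenizer
detok.reset()
for tok in token_ids:
    detok.add_token(tok)
    segment = detok.last_segment  # readable text since the last access
    if segment:
        print(segment, end="", flush=True)
detok.finalize()
print(detok.last_segment, end="")
```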
@@ -3,6 +3,45 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';

# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};

# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -42,11 +81,12 @@
'';

# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};
@@ -89,7 +89,10 @@

const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsTextToImage(currentModel);
return (
modelSupportsTextToImage(currentModel) ||
modelSupportsImageEditing(currentModel)
);
});

const isEditOnlyWithoutImage = $derived(

@@ -646,6 +649,23 @@
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
@@ -110,6 +110,36 @@
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function handleNumImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ numImages: 1 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 1) {
|
||||
setImageGenerationParams({ numImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamChange(enabled: boolean) {
|
||||
setImageGenerationParams({ stream: enabled });
|
||||
}
|
||||
|
||||
function handlePartialImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ partialImages: 0 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ partialImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
@@ -134,90 +164,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -325,6 +357,59 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Number of Images (not in edit mode) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>IMAGES:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="1"
|
||||
value={params.numImages}
|
||||
oninput={handleNumImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Stream toggle -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>STREAM:</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleStreamChange(!params.stream)}
|
||||
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
|
||||
? 'bg-exo-yellow'
|
||||
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
|
||||
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
|
||||
>
|
||||
<div
|
||||
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
|
||||
? 'right-0.5 bg-exo-black'
|
||||
: 'left-0.5 bg-exo-light-gray'}"
|
||||
></div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Partial Images (only when streaming) -->
|
||||
{#if params.stream}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>PARTIALS:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.partialImages}
|
||||
oninput={handlePartialImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
|
||||
@@ -238,6 +238,10 @@ export interface ImageGenerationParams {
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
// Streaming params
stream: boolean;
partialImages: number;
// Advanced params
seed: number | null;
numInferenceSteps: number | null;
@@ -257,6 +261,9 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,
stream: true,
partialImages: 3,
seed: null,
numInferenceSteps: null,
guidance: null,
@@ -1809,12 +1816,13 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
prompt,
n: params.numImages,
quality: params.quality,
size: params.size,
output_format: params.outputFormat,
response_format: "b64_json",
stream: true,
partial_images: 3,
stream: params.stream,
partial_images: params.partialImages,
};

if (hasAdvancedParams) {
@@ -1878,31 +1886,74 @@ class AppStore {
if (imageData && idx !== -1) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
const numImages = params.numImages;

if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.messages[idx].content =
`Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].content = progressText;

const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};

if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
this.messages[idx].attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
this.messages[idx].attachments = [
...finals,
partialAttachment,
];
}
} else if (parsed.type === "final") {
// Final image
this.messages[idx].content = "";
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};

if (imageIndex === 0) {
// First final image - replace any partial preview
this.messages[idx].attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
this.messages[idx].attachments = [
...previousFinals,
newAttachment,
];
}

// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
this.messages[idx].content =
`Generating image ${imageIndex + 2}/${numImages}...`;
} else {
this.messages[idx].content = "";
}
}
}
} catch {
@@ -1983,8 +2034,8 @@ class AppStore {
formData.append("size", params.size);
formData.append("output_format", params.outputFormat);
formData.append("response_format", "b64_json");
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
formData.append("partial_images", "3");
formData.append("stream", params.stream ? "1" : "0");
formData.append("partial_images", params.partialImages.toString());
formData.append("input_fidelity", params.inputFidelity);

// Advanced params
@@ -1,4 +1,5 @@
import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
@@ -33,6 +34,7 @@ from exo.shared.models.model_cards import (
ModelId,
)
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
@@ -835,6 +837,7 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),
@@ -1024,6 +1027,9 @@ class API:
stream: bool,
partial_images: int,
bench: bool,
quality: Literal["high", "medium", "low"],
output_format: Literal["png", "jpeg", "webp"],
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
@@ -1052,6 +1058,9 @@ class API:
stream=stream,
partial_images=partial_images,
bench=bench,
quality=quality,
output_format=output_format,
advanced_params=advanced_params,
),
)

@@ -1086,12 +1095,22 @@ class API:
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0

parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)

command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1103,6 +1122,9 @@ class API:
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)

if stream_bool and partial_images_int > 0:
@@ -1133,8 +1155,18 @@ class API:
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)

command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1146,6 +1178,9 @@ class API:
stream=False,
partial_images=0,
bench=True,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)

return await self._collect_image_generation_with_stats(
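(Aside, not part of the diff.) Multipart form fields arrive as strings, so the edits endpoint above parses `stream` and `partial_images` by hand. A standalone sketch of the same parsing rules; the helper names are illustrative, not exo APIs:

```python
# Standalone sketch of the form-value parsing used above; parse_stream_flag
# and parse_partial_images are illustrative helpers, not exo APIs.
def parse_stream_flag(value: str) -> bool:
    # Accept "true"/"1"/"yes" (case-insensitive); everything else is False.
    return value.lower() in ("true", "1", "yes")

def parse_partial_images(value: str) -> int:
    # Non-numeric input falls back to 0 partial images.
    return int(value) if value.isdigit() else 0

assert parse_stream_flag("1") and parse_stream_flag("True")
assert not parse_stream_flag("0")
assert parse_partial_images("3") == 3
assert parse_partial_images("abc") == 0
```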
src/exo/shared/types/mlx.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""Shared types for MLX-related functionality."""

from mlx_lm.models.cache import (
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]
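(Aside, not part of the diff.) `KVCacheType` is a per-layer list of caches. A minimal sketch of building one by hand, assuming mlx_lm is installed; `num_layers` is a placeholder, and in exo this list is what `make_kv_cache` returns:

```python
# Minimal sketch, assuming mlx_lm is installed; num_layers is a placeholder.
from mlx_lm.models.cache import KVCache
from exo.shared.types.mlx import KVCacheType

num_layers = 32
cache: KVCacheType = [KVCache() for _ in range(num_layers)]  # one entry per transformer layer
print(all(c.offset == 0 for c in cache))  # fresh caches start at offset 0
```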
@@ -30,6 +30,7 @@ class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -44,6 +45,7 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
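(Aside, not part of the diff.) Both response models gain an `image_index` field so each partial or final chunk can be tied to one of the `n` requested images. A hedged construction sketch; it assumes the models are importable from exo's runner-response module and that `BaseRunnerResponse` adds no other required fields:

```python
# Illustrative only: tagging partial/final chunks with the image they belong to.
# Import path and field set are assumptions based on the hunks above.
from exo.shared.types.worker.runner_response import (
    ImageGenerationResponse,
    PartialImageResponse,
)

partial = PartialImageResponse(
    image_data=b"...",   # encoded preview bytes (placeholder)
    format="png",
    partial_index=1,     # second preview of this image
    total_partials=3,
    image_index=0,       # first of the n requested images
)
final = ImageGenerationResponse(image_data=b"...", format="png", image_index=0)
```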
@@ -75,19 +75,20 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.

Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"

advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
seed = advanced_params.seed
base_seed = advanced_params.seed
else:
seed = random.randint(0, 2**32 - 1)
base_seed = random.randint(0, 2**32 - 1)

is_bench = getattr(task, "bench", False)
num_images = task.n or 1

generation_start_time: float = 0.0

@@ -95,7 +96,11 @@
mx.reset_peak_memory()
generation_start_time = time.perf_counter()

partial_images = task.partial_images or (3 if task.stream else 0)
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
)

image_path: Path | None = None

@@ -105,72 +110,81 @@
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))

# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num

yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)

stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time

num_inference_steps = model.get_steps_for_quality(quality)

seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
)
else:
image = result

peak_memory_gb = mx.get_peak_memory() / (1024**3)
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)

stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images

seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)

peak_memory_gb = mx.get_peak_memory() / (1024**3)

stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)

buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)

yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
)

buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)

yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)
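(Aside, not part of the diff.) Two details of the multi-image loop above: every image gets its own seed (`base_seed + image_num`), and bench stats are produced once, on the last image, with `seconds_per_step` averaged over the steps of all images. A small worked sketch of that arithmetic with illustrative numbers:

```python
# Illustrative arithmetic only, mirroring the bench-stats logic above.
base_seed = 1234
num_images = 3
seeds = [base_seed + i for i in range(num_images)]   # [1234, 1235, 1236], unique per image

num_inference_steps = 20                              # steps per image for this quality
total_steps = num_inference_steps * num_images        # 60
total_generation_time = 90.0                          # seconds, measured across all images
seconds_per_step = total_generation_time / total_steps if total_steps > 0 else 0.0
print(seeds, seconds_per_step)                        # 1.5 s/step averaged over the whole batch
```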
@@ -1,39 +1,53 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable

import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper

from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger


class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
self.caches: list[KVCacheType] = []

def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()

def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
tokenized_prompt = encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
logger.info(f"KV cache saved: {len(tokenized_prompt)} tokens")

def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
) -> tuple[KVCacheType, mx.array]:
"""Get KV cache for prompt, returning remaining tokens to prefill.

This method finds the best matching cached prefix and returns:
- A copy of the cache trimmed to the prefix length
- The remaining tokens that need to be prefilled before generation

The caller is responsible for prefilling the remaining tokens.

Returns:
Tuple of (cache, remaining_tokens) where remaining_tokens are the
tokens that still need to be prefilled/processed.
"""
tokenized_prompt = encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)

best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +56,75 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)

if length == max_length:
return self.caches[i]
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:]

if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length

if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]

else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)

prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
prompt_cache = deepcopy(self.caches[best_snapshot_index])

return prompt_cache
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(prompt_cache, tokens_to_trim)

def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
# Return remaining tokens for caller to prefill
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens

else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)

# Return all tokens for caller to prefill
return prompt_cache, tokenized_prompt


def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.

For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)


def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache)


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0

equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())

|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: mx.array,
|
||||
cache: list[_BaseCache],
|
||||
) -> None:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=0,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -1,12 +1,12 @@
from typing import Any, Callable, Generator, cast, get_args
import time
from typing import Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,11 +14,13 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -30,19 +32,59 @@ from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())


def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.

This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.

Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0

logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()

def progress_callback(processed: int, total: int) -> None:
elapsed = time.time() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)

# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cache, 1)

elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec


def warmup_inference(
@@ -120,6 +162,7 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -131,7 +174,16 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)

caches = make_kv_cache(model=model)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None

# Use prefix cache if available, otherwise create fresh cache
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)

logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -144,11 +196,19 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)

# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)

# stream_generate starts from the last token
last_token = prompt_tokens[-1:]

max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -158,12 +218,13 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)

stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -185,6 +246,22 @@ def mlx_generate(
)

if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {len(prompt_tokens)} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
# Save cache for future prefix matching (clear first to keep only the last one)
if kv_prefix_cache is not None:
kv_prefix_cache.clear()
full_prompt = prompt + "".join(generated_text_parts)
kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
break

# TODO: Do we want an mx_barrier?
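(Aside, not part of the diff.) Putting the pieces together, the intended flow is: ask `KVPrefixCache.get_kv_cache` for a `(cache, remaining_tokens)` pair, prefill everything but the last remaining token, then let `stream_generate` continue from that last token. A condensed, hedged sketch of that control flow, reusing the names imported in generate.py (not a drop-in implementation; the `get_kv_cache` call signature follows the call site above):

```python
# Condensed sketch of the prefix-cache flow above; assumes the same imports as
# generate.py (Model, TokenizerWrapper, KVPrefixCache, prefill, encode_prompt,
# make_kv_cache, stream_generate) and elides sampling/stats details.
def generate_with_prefix_cache(model, tokenizer, sampler, prompt, kv_prefix_cache):
    if kv_prefix_cache is None:
        caches = make_kv_cache(model=model)               # fresh per-layer caches
        prompt_tokens = encode_prompt(tokenizer, prompt)
    else:
        # Reuse the longest cached prefix; only the tail still needs prefill.
        caches, prompt_tokens = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)

    prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)  # fill cache up to the last token
    last_token = prompt_tokens[-1:]                                 # stream_generate starts here

    for out in stream_generate(
        model=model, tokenizer=tokenizer, prompt=last_token,
        max_tokens=256, sampler=sampler, prompt_cache=caches,
    ):
        yield out.text
```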
@@ -70,6 +70,7 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -103,6 +104,7 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None

current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -171,6 +173,9 @@ def main(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)

if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
kv_prefix_cache = KVPrefixCache()

current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -238,6 +243,7 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)

# GPT-OSS specific parsing to match other model formats.
@@ -612,7 +618,7 @@ def _process_image_response(
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.partial_index if is_partial else image_index,
image_index=response.image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,