mirror of
https://github.com/ollama/ollama.git
synced 2026-01-16 11:29:26 -05:00
Compare commits
1 Commits
v0.14.0-rc
...
usage-anal
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d132315276 |
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
@@ -68,7 +68,6 @@ jobs:
|
||||
name: bundles-darwin
|
||||
path: |
|
||||
dist/*.tgz
|
||||
dist/*.tar.zst
|
||||
dist/*.zip
|
||||
dist/*.dmg
|
||||
|
||||
@@ -393,13 +392,13 @@ jobs:
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
*.tar.zst
|
||||
*.tgz
|
||||
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
@@ -532,7 +531,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(Ollama C CXX)
|
||||
|
||||
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||
# target architecture.
|
||||
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||
# Single architecture specified
|
||||
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -163,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,28 +83,6 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -162,21 +140,6 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
36
Dockerfile
36
Dockerfile
@@ -131,39 +131,7 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
# TODO wire up the actual MLX engine here instead of building the main binary...
|
||||
RUN mkdir -p dist/bin
|
||||
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
|
||||
FROM base AS build
|
||||
@@ -185,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
@@ -377,6 +377,15 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
// Usage returns usage statistics and system info.
|
||||
func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
|
||||
var ur UsageResponse
|
||||
if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ur, nil
|
||||
}
|
||||
|
||||
// Copy copies a model - creating a model with another name from an existing
|
||||
// model.
|
||||
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
||||
|
||||
27
api/types.go
27
api/types.go
@@ -792,6 +792,33 @@ type ProcessResponse struct {
|
||||
Models []ProcessModelResponse `json:"models"`
|
||||
}
|
||||
|
||||
// UsageResponse is the response from [Client.Usage].
|
||||
type UsageResponse struct {
|
||||
GPUs []GPUUsage `json:"gpus,omitempty"`
|
||||
}
|
||||
|
||||
// GPUUsage contains GPU/device memory usage breakdown.
|
||||
type GPUUsage struct {
|
||||
Name string `json:"name"` // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
|
||||
Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
|
||||
Total uint64 `json:"total"`
|
||||
Free uint64 `json:"free"`
|
||||
Used uint64 `json:"used"` // Memory used by Ollama
|
||||
Other uint64 `json:"other"` // Memory used by other processes
|
||||
}
|
||||
|
||||
// UsageStats contains usage statistics.
|
||||
type UsageStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
TokensInput int64 `json:"tokens_input"`
|
||||
TokensOutput int64 `json:"tokens_output"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Models map[string]int64 `json:"models,omitempty"`
|
||||
Sources map[string]int64 `json:"sources,omitempty"`
|
||||
ToolCalls int64 `json:"tool_calls,omitempty"`
|
||||
StructuredOutput int64 `json:"structured_output,omitempty"`
|
||||
}
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -1833,6 +1833,7 @@ func NewCLI() *cobra.Command {
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ListRunningHandler,
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
|
||||
@@ -6,14 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -21,13 +18,8 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
@@ -41,94 +33,8 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() KV {
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := KV{
|
||||
kv := ggml.KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -206,7 +107,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ofs.Config) KV
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -214,7 +115,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
return errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -7,7 +9,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,6 +3,8 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -53,7 +55,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -19,13 +18,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
@@ -34,7 +34,7 @@ type olmoModel struct {
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -29,7 +28,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
for k, v := range kv {
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV KV
|
||||
BaseKV map[string]any
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -41,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -50,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -146,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
|
||||
@@ -206,6 +206,8 @@ var (
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
EnableVulkan = Bool("OLLAMA_VULKAN")
|
||||
// Usage enables usage statistics reporting
|
||||
Usage = Bool("OLLAMA_USAGE")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -13,8 +11,4 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
|
||||
var ErrInterrupt = errors.New("Interrupt")
|
||||
|
||||
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
|
||||
var ErrExpandOutput = errors.New("ExpandOutput")
|
||||
|
||||
type InterruptError struct {
|
||||
Line []rune
|
||||
}
|
||||
|
||||
@@ -206,6 +206,9 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.DeleteBefore()
|
||||
case CharCtrlL:
|
||||
buf.ClearScreen()
|
||||
case CharCtrlO:
|
||||
// Ctrl+O - expand tool output
|
||||
return "", ErrExpandOutput
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
|
||||
@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
|
||||
_build_darwin() {
|
||||
for ARCH in $ARCHS; do
|
||||
status "Building darwin $ARCH"
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
|
||||
if [ "$ARCH" = "amd64" ]; then
|
||||
status "Building darwin $ARCH dynamic backends"
|
||||
BUILD_DIR=build/darwin-$ARCH
|
||||
cmake -B $BUILD_DIR \
|
||||
cmake -B build/darwin-$ARCH \
|
||||
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
|
||||
-DMLX_ENGINE=ON \
|
||||
-DMLX_ENABLE_X64_MAC=ON \
|
||||
-DOLLAMA_RUNNER_DIR=./
|
||||
cmake --build $BUILD_DIR --target ggml-cpu -j
|
||||
cmake --build $BUILD_DIR --target mlx mlxc -j
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
-DOLLAMA_RUNNER_DIR=./ \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
cmake --build build/darwin-$ARCH --target ggml-cpu -j
|
||||
cmake --install build/darwin-$ARCH --component CPU
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -82,12 +61,10 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/imagegen
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
||||
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -154,23 +131,17 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -178,7 +149,7 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
|
||||
@@ -12,17 +12,6 @@ set -eu
|
||||
|
||||
. $(dirname $0)/env.sh
|
||||
|
||||
# Check for required tools
|
||||
if ! command -v zstd >/dev/null 2>&1; then
|
||||
echo "ERROR: zstd is required but not installed." >&2
|
||||
echo "Please install zstd:" >&2
|
||||
echo " - macOS: brew install zstd" >&2
|
||||
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
|
||||
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
|
||||
echo " - Arch: sudo pacman -S zstd" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
docker buildx build \
|
||||
@@ -48,68 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
echo "Compressing linux tar bundles..."
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
fi
|
||||
|
||||
@@ -66,36 +66,6 @@ if [ -n "$NEEDS" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Function to download and extract with fallback from zst to tgz
|
||||
download_and_extract() {
|
||||
local url_base="$1"
|
||||
local dest_dir="$2"
|
||||
local filename="$3"
|
||||
|
||||
# Check if .tar.zst is available
|
||||
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
|
||||
# zst file exists - check if we have zstd tool
|
||||
if ! available zstd; then
|
||||
error "This version requires zstd for extraction. Please install zstd and try again:
|
||||
- Debian/Ubuntu: sudo apt-get install zstd
|
||||
- RHEL/CentOS/Fedora: sudo dnf install zstd
|
||||
- Arch: sudo pacman -S zstd"
|
||||
fi
|
||||
|
||||
status "Downloading ${filename}.tar.zst"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
|
||||
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Fall back to .tgz for older versions
|
||||
status "Downloading ${filename}.tgz"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "${dest_dir}"
|
||||
}
|
||||
|
||||
for BINDIR in /usr/local/bin /usr/bin /bin; do
|
||||
echo $PATH | grep -q $BINDIR && break || continue
|
||||
done
|
||||
@@ -108,7 +78,10 @@ fi
|
||||
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
|
||||
status "Downloading Linux ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
|
||||
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
||||
status "Making ollama accessible in the PATH in $BINDIR"
|
||||
@@ -118,9 +91,15 @@ fi
|
||||
# Check for NVIDIA JetPack systems with additional downloads
|
||||
if [ -f /etc/nv_tegra_release ] ; then
|
||||
if grep R36 /etc/nv_tegra_release > /dev/null ; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
|
||||
status "Downloading JetPack 6 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
|
||||
status "Downloading JetPack 5 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
else
|
||||
warning "Unsupported JetPack version detected. GPU may not be supported"
|
||||
fi
|
||||
@@ -243,7 +222,10 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
|
||||
fi
|
||||
|
||||
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
|
||||
status "Downloading Linux ROCm ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
|
||||
install_success
|
||||
status "AMD GPU ready."
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
@@ -455,7 +454,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
for _, l := range baseLayers {
|
||||
if l.GGML != nil {
|
||||
return l.KV(), nil
|
||||
|
||||
128
server/routes.go
128
server/routes.go
@@ -20,6 +20,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -44,6 +45,7 @@ import (
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/server/usage"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/tools"
|
||||
@@ -82,6 +84,7 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
stats *usage.Stats
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -104,6 +107,30 @@ var (
|
||||
errBadTemplate = errors.New("template error")
|
||||
)
|
||||
|
||||
// usage records a request to usage stats if enabled.
|
||||
func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
|
||||
if s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.Record(&usage.Request{
|
||||
Endpoint: endpoint,
|
||||
Model: model,
|
||||
Architecture: architecture,
|
||||
APIType: usage.ClassifyAPIType(c.Request.URL.Path),
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
UsedTools: usedTools,
|
||||
})
|
||||
}
|
||||
|
||||
// usageError records a failed request to usage stats if enabled.
|
||||
func (s *Server) usageError() {
|
||||
if s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.RecordError()
|
||||
}
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
@@ -374,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -561,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
|
||||
|
||||
if !req.Raw {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||
@@ -680,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -790,6 +818,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: int(totalTokens),
|
||||
}
|
||||
s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -827,7 +856,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1531,6 +1560,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.GET("/api/usage", s.UsageHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
@@ -1593,6 +1623,13 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
// Initialize usage stats if enabled
|
||||
if envconfig.Usage() {
|
||||
s.stats = usage.New()
|
||||
s.stats.Start()
|
||||
slog.Info("usage stats enabled")
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
var err error
|
||||
@@ -1632,6 +1669,9 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if s.stats != nil {
|
||||
s.stats.Stop()
|
||||
}
|
||||
srvr.Close()
|
||||
schedDone()
|
||||
sched.unloadAllRunners()
|
||||
@@ -1649,6 +1689,24 @@ func Serve(ln net.Listener) error {
|
||||
gpus := discover.GPUDevices(ctx, nil)
|
||||
discover.LogDetails(gpus)
|
||||
|
||||
// Set GPU info for usage reporting
|
||||
if s.stats != nil {
|
||||
usage.GPUInfoFunc = func() []usage.GPU {
|
||||
var result []usage.GPU
|
||||
for _, gpu := range gpus {
|
||||
result = append(result, usage.GPU{
|
||||
Name: gpu.Name,
|
||||
VRAMBytes: gpu.TotalMemory,
|
||||
ComputeMajor: gpu.ComputeMajor,
|
||||
ComputeMinor: gpu.ComputeMinor,
|
||||
DriverMajor: gpu.DriverMajor,
|
||||
DriverMinor: gpu.DriverMinor,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
var totalVRAM uint64
|
||||
for _, gpu := range gpus {
|
||||
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
|
||||
@@ -1852,6 +1910,63 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
func (s *Server) UsageHandler(c *gin.Context) {
|
||||
// Get total VRAM used by Ollama
|
||||
s.sched.loadedMu.Lock()
|
||||
var totalOllamaVRAM uint64
|
||||
for _, runner := range s.sched.loaded {
|
||||
totalOllamaVRAM += runner.vramSize
|
||||
}
|
||||
s.sched.loadedMu.Unlock()
|
||||
|
||||
var resp api.UsageResponse
|
||||
|
||||
// Get GPU/device info
|
||||
gpus := discover.GPUDevices(c.Request.Context(), nil)
|
||||
|
||||
// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
|
||||
// because unified memory means GPU and CPU share the same physical RAM pool
|
||||
var sysTotal, sysFree uint64
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
sysInfo := discover.GetSystemInfo()
|
||||
sysTotal = sysInfo.TotalMemory
|
||||
sysFree = sysInfo.FreeMemory
|
||||
}
|
||||
|
||||
for _, gpu := range gpus {
|
||||
total := gpu.TotalMemory
|
||||
free := gpu.FreeMemory
|
||||
|
||||
// On Apple Silicon, override with system memory values
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
|
||||
total = sysTotal
|
||||
free = sysFree
|
||||
}
|
||||
|
||||
used := total - free
|
||||
ollamaUsed := min(totalOllamaVRAM, used)
|
||||
otherUsed := used - ollamaUsed
|
||||
|
||||
// Use Description for Name (actual device name like "Apple M2 Max")
|
||||
// Fall back to backend name if Description is empty
|
||||
name := gpu.Description
|
||||
if name == "" {
|
||||
name = gpu.Name
|
||||
}
|
||||
|
||||
resp.GPUs = append(resp.GPUs, api.GPUUsage{
|
||||
Name: name,
|
||||
Backend: gpu.Library,
|
||||
Total: total,
|
||||
Free: free,
|
||||
Used: ollamaUsed,
|
||||
Other: otherUsed,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func toolCallId() string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
@@ -2032,7 +2147,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
s.handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2180,6 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
@@ -2355,6 +2471,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
@@ -2362,7 +2479,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
|
||||
s.usageError()
|
||||
switch {
|
||||
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -42,7 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
60
server/routes_usage_test.go
Normal file
60
server/routes_usage_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestUsageHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("empty server", func(t *testing.T) {
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
loaded: make(map[string]*runnerRef),
|
||||
},
|
||||
}
|
||||
|
||||
w := createRequest(t, s.UsageHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.UsageResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// GPUs may or may not be present depending on system
|
||||
// Just verify we can decode the response
|
||||
})
|
||||
|
||||
t.Run("response structure", func(t *testing.T) {
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
loaded: make(map[string]*runnerRef),
|
||||
},
|
||||
}
|
||||
|
||||
w := createRequest(t, s.UsageHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
// Verify we can decode the response as valid JSON
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The response should be a valid object (not null)
|
||||
if resp == nil {
|
||||
t.Error("expected non-nil response")
|
||||
}
|
||||
})
|
||||
}
|
||||
65
server/usage/reporter.go
Normal file
65
server/usage/reporter.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const (
|
||||
reportTimeout = 10 * time.Second
|
||||
usageURL = "https://ollama.com/api/usage"
|
||||
)
|
||||
|
||||
// HeartbeatResponse is the response from the heartbeat endpoint.
|
||||
type HeartbeatResponse struct {
|
||||
UpdateVersion string `json:"update_version,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateAvailable returns the available update version, if any.
|
||||
func (t *Stats) UpdateAvailable() string {
|
||||
if v := t.updateAvailable.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sendHeartbeat sends usage stats and checks for updates.
|
||||
func (t *Stats) sendHeartbeat(payload *Payload) {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var heartbeat HeartbeatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.updateAvailable.Store(heartbeat.UpdateVersion)
|
||||
}
|
||||
23
server/usage/source.go
Normal file
23
server/usage/source.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// API type constants
|
||||
const (
|
||||
APITypeOllama = "ollama"
|
||||
APITypeOpenAI = "openai"
|
||||
APITypeAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// ClassifyAPIType determines the API type from the request path.
|
||||
func ClassifyAPIType(path string) string {
|
||||
if strings.HasPrefix(path, "/v1/messages") {
|
||||
return APITypeAnthropic
|
||||
}
|
||||
if strings.HasPrefix(path, "/v1/") {
|
||||
return APITypeOpenAI
|
||||
}
|
||||
return APITypeOllama
|
||||
}
|
||||
324
server/usage/usage.go
Normal file
324
server/usage/usage.go
Normal file
@@ -0,0 +1,324 @@
|
||||
// Package usage provides in-memory usage statistics collection and reporting.
|
||||
package usage
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// Stats collects usage statistics in memory and reports them periodically.
|
||||
type Stats struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Atomic counters for hot path
|
||||
requestsTotal atomic.Int64
|
||||
tokensPrompt atomic.Int64
|
||||
tokensCompletion atomic.Int64
|
||||
errorsTotal atomic.Int64
|
||||
|
||||
// Map-based counters (require lock)
|
||||
endpoints map[string]int64
|
||||
architectures map[string]int64
|
||||
apis map[string]int64
|
||||
models map[string]*ModelStats // per-model stats
|
||||
|
||||
// Feature usage
|
||||
toolCalls atomic.Int64
|
||||
structuredOutput atomic.Int64
|
||||
|
||||
// Update info (set by reporter after pinging update endpoint)
|
||||
updateAvailable atomic.Value // string
|
||||
|
||||
// Reporter
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
interval time.Duration
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// ModelStats tracks per-model usage statistics.
|
||||
type ModelStats struct {
|
||||
Requests int64
|
||||
TokensInput int64
|
||||
TokensOutput int64
|
||||
}
|
||||
|
||||
// Request contains the data to record for a single request.
|
||||
type Request struct {
|
||||
Endpoint string // "chat", "generate", "embed"
|
||||
Model string // model name (e.g., "llama3.2:3b")
|
||||
Architecture string // model architecture (e.g., "llama", "qwen2")
|
||||
APIType string // "native" or "openai_compat"
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
UsedTools bool
|
||||
StructuredOutput bool
|
||||
}
|
||||
|
||||
// SystemInfo contains hardware information to report.
|
||||
type SystemInfo struct {
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
CPUCores int `json:"cpu_cores"`
|
||||
RAMBytes uint64 `json:"ram_bytes"`
|
||||
GPUs []GPU `json:"gpus,omitempty"`
|
||||
}
|
||||
|
||||
// GPU contains information about a GPU.
|
||||
type GPU struct {
|
||||
Name string `json:"name"`
|
||||
VRAMBytes uint64 `json:"vram_bytes"`
|
||||
ComputeMajor int `json:"compute_major,omitempty"`
|
||||
ComputeMinor int `json:"compute_minor,omitempty"`
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
DriverMinor int `json:"driver_minor,omitempty"`
|
||||
}
|
||||
|
||||
// Payload is the data sent to the heartbeat endpoint.
|
||||
type Payload struct {
|
||||
Version string `json:"version"`
|
||||
Time time.Time `json:"time"`
|
||||
System SystemInfo `json:"system"`
|
||||
|
||||
Totals struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Errors int64 `json:"errors"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"totals"`
|
||||
|
||||
Endpoints map[string]int64 `json:"endpoints"`
|
||||
Architectures map[string]int64 `json:"architectures"`
|
||||
APIs map[string]int64 `json:"apis"`
|
||||
|
||||
Features struct {
|
||||
ToolCalls int64 `json:"tool_calls"`
|
||||
StructuredOutput int64 `json:"structured_output"`
|
||||
} `json:"features"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// New creates a new Stats instance.
|
||||
func New(opts ...Option) *Stats {
|
||||
t := &Stats{
|
||||
endpoints: make(map[string]int64),
|
||||
architectures: make(map[string]int64),
|
||||
apis: make(map[string]int64),
|
||||
models: make(map[string]*ModelStats),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
interval: defaultInterval,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Option configures the Stats instance.
|
||||
type Option func(*Stats)
|
||||
|
||||
// WithInterval sets the reporting interval.
|
||||
func WithInterval(d time.Duration) Option {
|
||||
return func(t *Stats) {
|
||||
t.interval = d
|
||||
}
|
||||
}
|
||||
|
||||
// Record records a request. This is the hot path and should be fast.
|
||||
func (t *Stats) Record(r *Request) {
|
||||
t.requestsTotal.Add(1)
|
||||
t.tokensPrompt.Add(int64(r.PromptTokens))
|
||||
t.tokensCompletion.Add(int64(r.CompletionTokens))
|
||||
|
||||
if r.UsedTools {
|
||||
t.toolCalls.Add(1)
|
||||
}
|
||||
if r.StructuredOutput {
|
||||
t.structuredOutput.Add(1)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.endpoints[r.Endpoint]++
|
||||
t.architectures[r.Architecture]++
|
||||
t.apis[r.APIType]++
|
||||
|
||||
// Track per-model stats
|
||||
if r.Model != "" {
|
||||
if t.models[r.Model] == nil {
|
||||
t.models[r.Model] = &ModelStats{}
|
||||
}
|
||||
t.models[r.Model].Requests++
|
||||
t.models[r.Model].TokensInput += int64(r.PromptTokens)
|
||||
t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// RecordError records a failed request.
|
||||
func (t *Stats) RecordError() {
|
||||
t.errorsTotal.Add(1)
|
||||
}
|
||||
|
||||
// GetModelStats returns a copy of per-model statistics.
|
||||
func (t *Stats) GetModelStats() map[string]*ModelStats {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ModelStats, len(t.models))
|
||||
for k, v := range t.models {
|
||||
result[k] = &ModelStats{
|
||||
Requests: v.Requests,
|
||||
TokensInput: v.TokensInput,
|
||||
TokensOutput: v.TokensOutput,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// View returns current stats without resetting counters.
|
||||
func (t *Stats) View() *Payload {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Copy maps
|
||||
endpoints := make(map[string]int64, len(t.endpoints))
|
||||
for k, v := range t.endpoints {
|
||||
endpoints[k] = v
|
||||
}
|
||||
architectures := make(map[string]int64, len(t.architectures))
|
||||
for k, v := range t.architectures {
|
||||
architectures[k] = v
|
||||
}
|
||||
apis := make(map[string]int64, len(t.apis))
|
||||
for k, v := range t.apis {
|
||||
apis[k] = v
|
||||
}
|
||||
|
||||
p := &Payload{
|
||||
Version: version.Version,
|
||||
Time: now,
|
||||
System: getSystemInfo(),
|
||||
Endpoints: endpoints,
|
||||
Architectures: architectures,
|
||||
APIs: apis,
|
||||
}
|
||||
|
||||
p.Totals.Requests = t.requestsTotal.Load()
|
||||
p.Totals.Errors = t.errorsTotal.Load()
|
||||
p.Totals.InputTokens = t.tokensPrompt.Load()
|
||||
p.Totals.OutputTokens = t.tokensCompletion.Load()
|
||||
p.Features.ToolCalls = t.toolCalls.Load()
|
||||
p.Features.StructuredOutput = t.structuredOutput.Load()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Snapshot returns current stats and resets counters.
|
||||
func (t *Stats) Snapshot() *Payload {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
p := &Payload{
|
||||
Version: version.Version,
|
||||
Time: now,
|
||||
System: getSystemInfo(),
|
||||
Endpoints: t.endpoints,
|
||||
Architectures: t.architectures,
|
||||
APIs: t.apis,
|
||||
}
|
||||
|
||||
p.Totals.Requests = t.requestsTotal.Swap(0)
|
||||
p.Totals.Errors = t.errorsTotal.Swap(0)
|
||||
p.Totals.InputTokens = t.tokensPrompt.Swap(0)
|
||||
p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
|
||||
p.Features.ToolCalls = t.toolCalls.Swap(0)
|
||||
p.Features.StructuredOutput = t.structuredOutput.Swap(0)
|
||||
|
||||
// Reset maps
|
||||
t.endpoints = make(map[string]int64)
|
||||
t.architectures = make(map[string]int64)
|
||||
t.apis = make(map[string]int64)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// getSystemInfo collects hardware information.
|
||||
func getSystemInfo() SystemInfo {
|
||||
info := SystemInfo{
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Get CPU and memory info
|
||||
sysInfo := discover.GetSystemInfo()
|
||||
info.CPUCores = sysInfo.ThreadCount
|
||||
info.RAMBytes = sysInfo.TotalMemory
|
||||
|
||||
// Get GPU info
|
||||
gpus := getGPUInfo()
|
||||
info.GPUs = gpus
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GPUInfoFunc is a function that returns GPU information.
|
||||
// It's set by the server package after GPU discovery.
|
||||
var GPUInfoFunc func() []GPU
|
||||
|
||||
// getGPUInfo collects GPU information.
|
||||
func getGPUInfo() []GPU {
|
||||
if GPUInfoFunc != nil {
|
||||
return GPUInfoFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins the periodic reporting goroutine.
|
||||
func (t *Stats) Start() {
|
||||
go t.reportLoop()
|
||||
}
|
||||
|
||||
// Stop stops reporting and waits for the final report.
|
||||
func (t *Stats) Stop() {
|
||||
close(t.stopCh)
|
||||
<-t.doneCh
|
||||
}
|
||||
|
||||
// reportLoop runs the periodic reporting.
|
||||
func (t *Stats) reportLoop() {
|
||||
defer close(t.doneCh)
|
||||
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.report()
|
||||
case <-t.stopCh:
|
||||
// Send final report before stopping
|
||||
t.report()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// report sends usage stats and checks for updates.
|
||||
func (t *Stats) report() {
|
||||
payload := t.Snapshot()
|
||||
t.sendHeartbeat(payload)
|
||||
}
|
||||
194
server/usage/usage_test.go
Normal file
194
server/usage/usage_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
stats := New()
|
||||
if stats == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
Architecture: "llama",
|
||||
APIType: "native",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
UsedTools: true,
|
||||
StructuredOutput: false,
|
||||
})
|
||||
|
||||
// Check totals
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
|
||||
}
|
||||
if payload.Totals.InputTokens != 100 {
|
||||
t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
|
||||
}
|
||||
if payload.Totals.OutputTokens != 50 {
|
||||
t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
|
||||
}
|
||||
if payload.Features.ToolCalls != 1 {
|
||||
t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
|
||||
}
|
||||
if payload.Features.StructuredOutput != 0 {
|
||||
t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelStats(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
// Record requests for multiple models
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
})
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 200,
|
||||
CompletionTokens: 100,
|
||||
})
|
||||
stats.Record(&Request{
|
||||
Model: "mistral:7b",
|
||||
PromptTokens: 50,
|
||||
CompletionTokens: 25,
|
||||
})
|
||||
|
||||
modelStats := stats.GetModelStats()
|
||||
|
||||
// Check llama3:8b stats
|
||||
llama := modelStats["llama3:8b"]
|
||||
if llama == nil {
|
||||
t.Fatal("expected llama3:8b stats")
|
||||
}
|
||||
if llama.Requests != 2 {
|
||||
t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
|
||||
}
|
||||
if llama.TokensInput != 300 {
|
||||
t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
|
||||
}
|
||||
if llama.TokensOutput != 150 {
|
||||
t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
|
||||
}
|
||||
|
||||
// Check mistral:7b stats
|
||||
mistral := modelStats["mistral:7b"]
|
||||
if mistral == nil {
|
||||
t.Fatal("expected mistral:7b stats")
|
||||
}
|
||||
if mistral.Requests != 1 {
|
||||
t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
|
||||
}
|
||||
if mistral.TokensInput != 50 {
|
||||
t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
|
||||
}
|
||||
if mistral.TokensOutput != 25 {
|
||||
t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordError(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.RecordError()
|
||||
stats.RecordError()
|
||||
|
||||
payload := stats.View()
|
||||
if payload.Totals.Errors != 2 {
|
||||
t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestView(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
Architecture: "llama",
|
||||
APIType: "native",
|
||||
})
|
||||
|
||||
// First view
|
||||
_ = stats.View()
|
||||
|
||||
// View should not reset counters
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1 {
|
||||
t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
Endpoint: "chat",
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 50,
|
||||
})
|
||||
|
||||
// Snapshot should return data and reset counters
|
||||
snapshot := stats.Snapshot()
|
||||
if snapshot.Totals.Requests != 1 {
|
||||
t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
|
||||
}
|
||||
|
||||
// After snapshot, counters should be reset
|
||||
payload2 := stats.View()
|
||||
if payload2.Totals.Requests != 0 {
|
||||
t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
stats := New()
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
stats.Record(&Request{
|
||||
Model: "llama3:8b",
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 5,
|
||||
})
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = stats.View()
|
||||
_ = stats.GetModelStats()
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 15; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
payload := stats.View()
|
||||
if payload.Totals.Requests != 1000 {
|
||||
t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
|
||||
}
|
||||
}
|
||||
24
x/README.md
24
x/README.md
@@ -1,24 +0,0 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
||||
|
||||
```
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install --component MLX
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
||||
|
||||
## Image Generation
|
||||
|
||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
||||
|
||||
```
|
||||
go build -o imagegen ./x/imagegen/cmd/engine
|
||||
```
|
||||
@@ -37,25 +37,6 @@ var optionLabels = []string{
|
||||
"3. Deny",
|
||||
}
|
||||
|
||||
// toolDisplayNames maps internal tool names to human-readable display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"web_search": "Web Search",
|
||||
}
|
||||
|
||||
// ToolDisplayName returns the human-readable display name for a tool.
|
||||
func ToolDisplayName(toolName string) string {
|
||||
if displayName, ok := toolDisplayNames[toolName]; ok {
|
||||
return displayName
|
||||
}
|
||||
// Default: capitalize first letter and replace underscores with spaces
|
||||
name := strings.ReplaceAll(toolName, "_", " ")
|
||||
if len(name) > 0 {
|
||||
return strings.ToUpper(name[:1]) + name[1:]
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// autoAllowCommands are commands that are always allowed without prompting.
|
||||
// These are zero-risk, read-only commands.
|
||||
var autoAllowCommands = map[string]bool{
|
||||
@@ -528,12 +509,11 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
|
||||
// formatToolDisplay creates the display string for a tool call.
|
||||
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
var sb strings.Builder
|
||||
displayName := ToolDisplayName(toolName)
|
||||
|
||||
// For bash, show command directly
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||
return sb.String()
|
||||
}
|
||||
@@ -542,7 +522,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
// For web search, show query and internet notice
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
return sb.String()
|
||||
@@ -550,7 +530,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||
if len(args) > 0 {
|
||||
sb.WriteString("\nArguments: ")
|
||||
first := true
|
||||
@@ -744,7 +724,7 @@ func wrapText(text string, maxWidth int) []string {
|
||||
|
||||
// getHintLines returns the hint text wrapped to terminal width
|
||||
func getHintLines(state *selectorState) []string {
|
||||
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
|
||||
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||
if state.termWidth >= len(hint)+1 {
|
||||
return []string{hint}
|
||||
}
|
||||
@@ -754,60 +734,86 @@ func getHintLines(state *selectorState) []string {
|
||||
|
||||
// calculateTotalLines calculates how many lines the selector will use
|
||||
func calculateTotalLines(state *selectorState) int {
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
|
||||
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||
warningLines := 0
|
||||
if state.isWarning {
|
||||
warningLines = 2 // warning line + blank line after
|
||||
warningLines = 1
|
||||
}
|
||||
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
}
|
||||
|
||||
// renderSelectorBox renders the selector (minimal, no box)
|
||||
// renderSelectorBox renders the complete selector box
|
||||
func renderSelectorBox(state *selectorState) {
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Draw warning line if needed
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Draw tool info (plain white)
|
||||
// Draw box top
|
||||
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw warning line if needed (inside the box)
|
||||
if state.isWarning {
|
||||
warning := "!! OUTSIDE PROJECT !!"
|
||||
padding := (state.innerWidth - len(warning)) / 2
|
||||
if padding < 0 {
|
||||
padding = 0
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||
}
|
||||
|
||||
// Draw tool info
|
||||
for _, line := range toolLines {
|
||||
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||
}
|
||||
|
||||
// Blank line separator
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Draw separator
|
||||
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw options
|
||||
// Draw options with numbers (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option with input
|
||||
if i == 2 { // Deny option - show with reason input beside it
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blank line before hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Draw box bottom
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw hint (dark grey)
|
||||
// Draw hint (may be multiple lines)
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
// Last line - no newline
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
@@ -819,33 +825,50 @@ func renderSelectorBox(state *selectorState) {
|
||||
func updateSelectorOptions(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the first option line
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (blank line) + numOptions
|
||||
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options
|
||||
// Redraw options (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -859,23 +882,36 @@ func updateSelectorOptions(state *selectorState) {
|
||||
func updateReasonInput(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the Deny line (3rd option, index 2)
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
|
||||
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw Deny line with reason
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if state.selected == 2 {
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -899,10 +935,11 @@ func clearSelectorBox(state *selectorState) {
|
||||
// fallbackApproval handles approval when terminal control isn't available.
|
||||
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "choice: ")
|
||||
fmt.Fprint(os.Stderr, "Choice: ")
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
@@ -945,16 +982,19 @@ func (a *ApprovalManager) AllowedTools() []string {
|
||||
|
||||
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||
var label string
|
||||
displayName := ToolDisplayName(toolName)
|
||||
var status string
|
||||
var icon string
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
label = "approved"
|
||||
status = "Approved"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalAlways:
|
||||
label = "always allowed"
|
||||
status = "Always allowed"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalDeny:
|
||||
label = "denied"
|
||||
status = "Denied"
|
||||
icon = "\033[31m✗\033[0m"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
@@ -964,7 +1004,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(cmd) > 40 {
|
||||
cmd = cmd[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
|
||||
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -974,11 +1014,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(query) > 40 {
|
||||
query = query[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
|
||||
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||
}
|
||||
|
||||
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||
@@ -1009,14 +1049,15 @@ func PromptYesNo(question string) (bool, error) {
|
||||
renderYesNo := func() {
|
||||
// Move to start of line and clear
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
fmt.Fprintf(os.Stderr, "%s ", question)
|
||||
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
|
||||
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
|
||||
}
|
||||
|
||||
renderYesNo()
|
||||
|
||||
107
x/cmd/run.go
107
x/cmd/run.go
@@ -91,8 +91,8 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
|
||||
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
|
||||
|
||||
// Poll until auth succeeds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
@@ -106,7 +106,7 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
case <-ticker.C:
|
||||
user, whoamiErr := client.Whoami(ctx)
|
||||
if whoamiErr == nil && user != nil && user.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
|
||||
return nil
|
||||
}
|
||||
// Still waiting, show dot
|
||||
@@ -137,6 +137,13 @@ type RunOptions struct {
|
||||
|
||||
// YoloMode skips all tool approval prompts
|
||||
YoloMode bool
|
||||
|
||||
// LastToolOutput stores the full output of the last tool execution
|
||||
// for Ctrl+O expansion. Updated by Chat(), read by caller.
|
||||
LastToolOutput *string
|
||||
|
||||
// LastToolOutputTruncated stores the truncated version shown inline
|
||||
LastToolOutputTruncated *string
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
@@ -264,12 +271,12 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) {
|
||||
p.StopAndClear()
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the chat request
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
|
||||
continue // Retry the loop
|
||||
}
|
||||
}
|
||||
@@ -283,11 +290,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
p.StopAndClear()
|
||||
|
||||
if consecutiveErrors >= 3 {
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
|
||||
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
|
||||
// Include both the model's response and the error so it can learn
|
||||
assistantContent := fullResponse.String()
|
||||
@@ -353,8 +360,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check if command is denied (dangerous pattern)
|
||||
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
|
||||
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||
@@ -365,7 +372,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
}
|
||||
@@ -375,7 +382,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// In yolo mode, skip all approval prompts
|
||||
if opts.YoloMode {
|
||||
if !skipApproval {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
@@ -405,7 +412,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
} else if !skipApproval {
|
||||
// Already allowed - show running indicator
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
@@ -414,13 +421,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// Check if web search needs authentication
|
||||
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
|
||||
// Prompt user to sign in
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
// Get signin URL and wait for auth completion
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
// Retry the web search
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
|
||||
toolResult, err = toolRegistry.Execute(call)
|
||||
if err == nil {
|
||||
goto toolSuccess
|
||||
@@ -428,7 +435,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
@@ -439,15 +446,25 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
toolSuccess:
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
truncatedOutput := ""
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated)"
|
||||
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
|
||||
}
|
||||
truncatedOutput = output
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
// Store full and truncated output for Ctrl+O toggle
|
||||
if opts.LastToolOutput != nil {
|
||||
*opts.LastToolOutput = toolResult
|
||||
}
|
||||
if opts.LastToolOutputTruncated != nil {
|
||||
*opts.LastToolOutputTruncated = truncatedOutput
|
||||
}
|
||||
|
||||
// Truncate output to prevent context overflow
|
||||
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
|
||||
|
||||
@@ -499,18 +516,17 @@ func truncateUTF8(s string, limit int) string {
|
||||
|
||||
// formatToolShort returns a short description of a tool call.
|
||||
func formatToolShort(toolName string, args map[string]any) string {
|
||||
displayName := agent.ToolDisplayName(toolName)
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
|
||||
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||
}
|
||||
}
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
|
||||
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||
}
|
||||
}
|
||||
return displayName
|
||||
return toolName
|
||||
}
|
||||
|
||||
// Helper types and functions for display
|
||||
@@ -650,7 +666,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
// Check if model supports tools
|
||||
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||
supportsTools = false
|
||||
}
|
||||
|
||||
@@ -659,13 +675,13 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
if toolRegistry.Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
}
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
@@ -674,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
// Track last tool output for Ctrl+O toggle
|
||||
var lastToolOutput string
|
||||
var lastToolOutputTruncated string
|
||||
var toolOutputExpanded bool
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
@@ -686,6 +707,20 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case errors.Is(err, readline.ErrExpandOutput):
|
||||
// Ctrl+O pressed - toggle between expanded and collapsed tool output
|
||||
if lastToolOutput == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
|
||||
} else if toolOutputExpanded {
|
||||
// Currently expanded, show truncated
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
|
||||
toolOutputExpanded = false
|
||||
} else {
|
||||
// Currently collapsed, show full
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
|
||||
toolOutputExpanded = true
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
@@ -724,17 +759,21 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
LastToolOutput: &lastToolOutput,
|
||||
LastToolOutputTruncated: &lastToolOutputTruncated,
|
||||
}
|
||||
// Reset expanded state for new tool execution
|
||||
toolOutputExpanded = false
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
|
||||
38
x/imagegen/.gitignore
vendored
38
x/imagegen/.gitignore
vendored
@@ -1,38 +0,0 @@
|
||||
# Build directories
|
||||
build/
|
||||
dist/
|
||||
|
||||
# CMake
|
||||
CMakeCache.txt
|
||||
CMakeFiles/
|
||||
cmake_install.cmake
|
||||
Makefile
|
||||
*.cmake
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
*.dSYM/
|
||||
|
||||
# Go
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Python
|
||||
*.npy
|
||||
|
||||
/engine
|
||||
weights
|
||||
outputs
|
||||
|
||||
prompt.txt
|
||||
negative.txt
|
||||
@@ -1,61 +0,0 @@
|
||||
# imagegen
|
||||
|
||||
This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
|
||||
in `CMakeLists.txt` and rebuild.
|
||||
|
||||
### 1. Download a Model
|
||||
|
||||
Download Llama 3.1 8B (or any compatible model) in safetensors format:
|
||||
|
||||
```bash
|
||||
mkdir -p ./weights
|
||||
|
||||
# Example using huggingface-cli
|
||||
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
|
||||
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
|
||||
```
|
||||
|
||||
### 2. Run Inference
|
||||
|
||||
```bash
|
||||
# Build
|
||||
go build ./cmd/engine
|
||||
|
||||
# Text generation
|
||||
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
|
||||
|
||||
# Qwen-Image 2512 (text-to-image)
|
||||
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
|
||||
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
|
||||
|
||||
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
|
||||
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
|
||||
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
|
||||
-steps 8 -seed 42 -output edited.png
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
|
||||
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
|
||||
|
||||
Instead, arrays are automatically tracked and freed on `Eval()`:
|
||||
|
||||
```go
|
||||
// All arrays are automatically tracked when created
|
||||
x := mlx.Add(a, b)
|
||||
y := mlx.Matmul(x, w)
|
||||
|
||||
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
|
||||
mlx.Eval(y)
|
||||
|
||||
// After copying to CPU, free the array
|
||||
data := y.Data()
|
||||
y.Free()
|
||||
```
|
||||
|
||||
Key points:
|
||||
|
||||
- All created arrays are automatically tracked
|
||||
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
|
||||
- Call `.Free()` when done with an array
|
||||
156
x/imagegen/cache/cache.go
vendored
156
x/imagegen/cache/cache.go
vendored
@@ -1,156 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
type Cache interface {
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
Offset() int
|
||||
Len() int
|
||||
State() []*mlx.Array
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if needed
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
idx int
|
||||
}
|
||||
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if not yet at max
|
||||
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
|
||||
var cap int
|
||||
if c.keys != nil {
|
||||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
// Rotate when hitting max
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
164
x/imagegen/cache/step.go
vendored
164
x/imagegen/cache/step.go
vendored
@@ -1,164 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// StepCache caches layer outputs across diffusion denoising steps.
|
||||
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
// Usage (single-stream):
|
||||
//
|
||||
// cache := NewStepCache(15) // cache first 15 layers
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// output = cache.Get(i) // reuse cached
|
||||
// } else {
|
||||
// output = layer.Forward(input)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, output)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// cache.Free() // cleanup when done
|
||||
//
|
||||
// Usage (dual-stream):
|
||||
//
|
||||
// cache := NewStepCache(15)
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3)
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// imgH, txtH = cache.Get(i), cache.Get2(i)
|
||||
// } else {
|
||||
// imgH, txtH = layer.Forward(imgH, txtH, ...)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, imgH)
|
||||
// cache.Set2(i, txtH)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
type StepCache struct {
|
||||
layers []*mlx.Array // cached layer outputs (stream 1)
|
||||
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
|
||||
constant *mlx.Array // optional constant (e.g., text embeddings)
|
||||
}
|
||||
|
||||
// NewStepCache creates a cache for the given number of layers.
|
||||
func NewStepCache(numLayers int) *StepCache {
|
||||
return &StepCache{
|
||||
layers: make([]*mlx.Array, numLayers),
|
||||
layers2: make([]*mlx.Array, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldRefresh returns true if the cache should be refreshed at this step.
|
||||
// Refresh happens on step 0, interval, 2*interval, etc.
|
||||
func (c *StepCache) ShouldRefresh(step, interval int) bool {
|
||||
return step%interval == 0
|
||||
}
|
||||
|
||||
// Get returns the cached output for a layer, or nil if not cached.
|
||||
func (c *StepCache) Get(layer int) *mlx.Array {
|
||||
if layer < len(c.layers) {
|
||||
return c.layers[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set stores a layer output (stream 1), freeing any previous value.
|
||||
func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers) {
|
||||
if c.layers[layer] != nil {
|
||||
c.layers[layer].Free()
|
||||
}
|
||||
c.layers[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
c.layers2[layer].Free()
|
||||
}
|
||||
c.layers2[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// GetConstant returns the cached constant value.
|
||||
func (c *StepCache) GetConstant() *mlx.Array {
|
||||
return c.constant
|
||||
}
|
||||
|
||||
// SetConstant stores a constant value, freeing any previous value.
|
||||
func (c *StepCache) SetConstant(arr *mlx.Array) {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
}
|
||||
c.constant = arr
|
||||
}
|
||||
|
||||
// Arrays returns all non-nil cached arrays (for pool.Keep).
|
||||
func (c *StepCache) Arrays() []*mlx.Array {
|
||||
var result []*mlx.Array
|
||||
if c.constant != nil {
|
||||
result = append(result, c.constant)
|
||||
}
|
||||
for _, arr := range c.layers {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
for _, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Free releases all cached arrays. Call when generation completes.
|
||||
func (c *StepCache) Free() {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
c.constant = nil
|
||||
}
|
||||
for i, arr := range c.layers {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers[i] = nil
|
||||
}
|
||||
}
|
||||
for i, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers2[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NumLayers returns the number of layers this cache can store.
|
||||
func (c *StepCache) NumLayers() int {
|
||||
return len(c.layers)
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
|
||||
// This handles cases where tokenizers output partial multi-byte sequences.
|
||||
type utf8Streamer struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
|
||||
func (s *utf8Streamer) Write(text string) string {
|
||||
s.buffer = append(s.buffer, text...)
|
||||
|
||||
// Find the last position that ends with a complete UTF-8 character
|
||||
validLen := 0
|
||||
for i := 0; i < len(s.buffer); {
|
||||
r, size := utf8.DecodeRune(s.buffer[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
// Invalid or incomplete UTF-8 sequence at this position
|
||||
// Check if it could be a valid start of a multi-byte sequence
|
||||
if len(s.buffer)-i < 4 {
|
||||
// Might be incomplete, keep it in buffer
|
||||
break
|
||||
}
|
||||
// Definitely invalid, skip this byte
|
||||
i++
|
||||
validLen = i
|
||||
} else {
|
||||
i += size
|
||||
validLen = i
|
||||
}
|
||||
}
|
||||
|
||||
if validLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := string(s.buffer[:validLen])
|
||||
s.buffer = s.buffer[validLen:]
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
|
||||
func (s *utf8Streamer) Flush() string {
|
||||
if len(s.buffer) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := string(s.buffer)
|
||||
s.buffer = nil
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
// ChatModel is an optional interface for models that support chat formatting
|
||||
type ChatModel interface {
|
||||
FormatPrompt(prompt string) string
|
||||
}
|
||||
|
||||
// MultimodalModel is for models that support image input
|
||||
type MultimodalModel interface {
|
||||
Model
|
||||
FormatPromptWithImage(prompt string) string
|
||||
ExpandImageTokens(tokens []int32) []int32
|
||||
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
ImageSize() int32 // Returns expected image size for preprocessing
|
||||
}
|
||||
|
||||
// ImageLoader loads and preprocesses an image for multimodal models
|
||||
// Returns nil if path is empty
|
||||
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
|
||||
|
||||
type input struct {
|
||||
Prompt string
|
||||
Image *mlx.Array // Optional preprocessed image for multimodal models
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
}
|
||||
|
||||
type output struct {
|
||||
Text string
|
||||
Done bool
|
||||
PrefillTokSec float64
|
||||
GenTokSec float64
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
type Decoder struct {
|
||||
model Model
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// For multimodal models with an image, we need to process all tokens together
|
||||
// in the first forward pass so the image embeddings can be inserted properly.
|
||||
// Skip chunking for multimodal prefill.
|
||||
isMultimodal := d.image != nil
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
// Skip chunking for multimodal - process everything in the final step
|
||||
if !isMultimodal {
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling (or all tokens for multimodal)
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
|
||||
var logits *mlx.Array
|
||||
// Use ForwardWithImage if we have an image and model supports it
|
||||
if d.image != nil {
|
||||
if mm, ok := d.model.(MultimodalModel); ok {
|
||||
logits = mm.ForwardWithImage(x, d.image, d.caches)
|
||||
d.image = nil // Only use image for first forward
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
mlx.EnableCompile()
|
||||
wiredLimit := in.WiredLimitGB
|
||||
if wiredLimit <= 0 {
|
||||
wiredLimit = 32 // default 32GB
|
||||
}
|
||||
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
|
||||
|
||||
temp := in.Temperature
|
||||
if temp < 0 {
|
||||
temp = 0.7
|
||||
}
|
||||
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
|
||||
prompt = mm.FormatPromptWithImage(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
|
||||
dec.SetImage(in.Image)
|
||||
} else if cm, ok := m.(ChatModel); ok {
|
||||
prompt = cm.FormatPrompt(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
} else {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
// Step() waits for prefill to complete and returns first token
|
||||
firstToken := dec.step()
|
||||
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
|
||||
|
||||
genStart := time.Now()
|
||||
maxTokens := max(in.MaxTokens, 100)
|
||||
var genTokens int
|
||||
|
||||
// UTF-8 streamer to handle partial multi-byte characters
|
||||
streamer := &utf8Streamer{}
|
||||
|
||||
// Handle first token
|
||||
genTokens++
|
||||
if tok.IsEOS(firstToken) {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
|
||||
return nil
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
token := dec.step()
|
||||
genTokens++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
break
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining buffered bytes
|
||||
if text := streamer.Flush(); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec,
|
||||
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// saveImageArray saves an MLX array as a PNG image.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func saveImageArray(arr *mlx.Array, path string) error {
|
||||
img, err := arrayToImage(arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return savePNG(img, path)
|
||||
}
|
||||
|
||||
func savePNG(img *image.RGBA, path string) error {
|
||||
if filepath.Ext(path) != ".png" {
|
||||
path = path + ".png"
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, img)
|
||||
}
|
||||
|
||||
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
shape := arr.Shape()
|
||||
if len(shape) != 4 {
|
||||
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
arr.Free()
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
H := int(imgShape[0])
|
||||
W := int(imgShape[1])
|
||||
C := int(imgShape[2])
|
||||
|
||||
if C != 3 {
|
||||
img.Free()
|
||||
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
||||
}
|
||||
|
||||
// Copy to CPU and free GPU memory
|
||||
data := img.Data()
|
||||
img.Free()
|
||||
|
||||
// Write directly to Pix slice (faster than SetRGBA)
|
||||
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
||||
pix := goImg.Pix
|
||||
for y := 0; y < H; y++ {
|
||||
for x := 0; x < W; x++ {
|
||||
srcIdx := (y*W + x) * C
|
||||
dstIdx := (y*W + x) * 4
|
||||
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
||||
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
||||
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
||||
pix[dstIdx+3] = 255
|
||||
}
|
||||
}
|
||||
|
||||
return goImg, nil
|
||||
}
|
||||
|
||||
func clampF(v, min, max float32) float32 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -1,286 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// stringSlice is a flag type that accumulates multiple values
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
modelPath := flag.String("model", "", "Model directory")
|
||||
prompt := flag.String("prompt", "Hello", "Prompt")
|
||||
|
||||
// Text generation params
|
||||
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
||||
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
// Utility flags
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *modelPath == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle legacy mode flags that aren't unified yet
|
||||
switch {
|
||||
case *zimageFlag:
|
||||
m := &zimage.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
// llm path
|
||||
m, err := load(*modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatal("model does not support image input")
|
||||
}
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
MaxTokens: *maxTokens,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
}
|
||||
if out.Done {
|
||||
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listModelTensors(modelPath string) error {
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range weights.ListTensors() {
|
||||
info, _ := weights.GetTensorInfo(name)
|
||||
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModel builds and evaluates a model using the common load pattern.
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
||||
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
||||
func loadModel[T Model](build func() T, cleanup func()) T {
|
||||
m := build()
|
||||
weights := mlx.Collect(m)
|
||||
cleanup()
|
||||
mlx.Eval(weights...)
|
||||
return m
|
||||
}
|
||||
|
||||
func load(modelPath string) (Model, error) {
|
||||
kind, err := detectModelKind(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect model kind: %w", err)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "gpt_oss":
|
||||
return gpt_oss.Load(modelPath)
|
||||
case "gemma3":
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
default:
|
||||
return llama.Load(modelPath)
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelKind(modelPath string) (string, error) {
|
||||
indexPath := filepath.Join(modelPath, "model_index.json")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
data, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "zimage", nil
|
||||
}
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err == nil {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(modelPath, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// sampleTopK samples from top-k logits using global random state
|
||||
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
|
||||
neg := mlx.Neg(scaledLogits)
|
||||
indices := mlx.Argpartition(neg, k-1, -1)
|
||||
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
|
||||
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
|
||||
sampled := mlx.RandomCategorical(values, -1, 1)
|
||||
return mlx.Take(topKIdx, sampled, -1)
|
||||
}
|
||||
|
||||
// sampleTopP samples using nucleus sampling with global random state
|
||||
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
|
||||
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
|
||||
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
|
||||
probs := mlx.Softmax(sortedLogits, -1)
|
||||
cumProbs := mlx.Cumsum(probs, -1)
|
||||
mask := mlx.LessScalar(cumProbs, p)
|
||||
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
|
||||
masked := mlx.Where(mask, sortedLogits, negInf)
|
||||
sampled := mlx.RandomCategorical(masked, -1, 1)
|
||||
return mlx.Take(sorted, sampled, -1)
|
||||
}
|
||||
|
||||
// sample samples from logits at the last position
|
||||
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocab)
|
||||
|
||||
if temp == 0 {
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
if topK > 0 && topK < int(vocab) {
|
||||
return sampleTopK(scaled, topK)
|
||||
}
|
||||
if topP > 0 && topP < 1.0 {
|
||||
return sampleTopP(scaled, topP, vocab)
|
||||
}
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
# MLX Memory Management
|
||||
|
||||
| This package will get consolidated with `x/ml/backend/mlx` in the future.
|
||||
|
||||
## Automatic Tracking
|
||||
|
||||
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
|
||||
|
||||
### API
|
||||
|
||||
```go
|
||||
result := mlx.Matmul(x, w) // arrays automatically tracked
|
||||
mlx.Eval(result) // free non-kept, eval result (auto-kept)
|
||||
```
|
||||
|
||||
### Key Functions
|
||||
|
||||
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
|
||||
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
|
||||
- `array.Free()` - mark array for cleanup on next Eval
|
||||
|
||||
### Loop Pattern
|
||||
|
||||
```go
|
||||
for step := 0; step < maxTokens; step++ {
|
||||
logits := model.Forward(token, caches)
|
||||
oldToken := token
|
||||
token = sample(logits)
|
||||
|
||||
// Keep cache state across iterations
|
||||
for _, c := range caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
|
||||
oldToken.Free() // mark for cleanup
|
||||
mlx.AsyncEval(token) // frees old, evals new
|
||||
}
|
||||
```
|
||||
|
||||
### Notes
|
||||
|
||||
- `Eval()` and `AsyncEval()` auto-keep their outputs
|
||||
- `Free()` marks for cleanup - actual free happens during next Eval
|
||||
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
|
||||
- Arrays created inside compiled closures are managed by MLX, not tracked
|
||||
@@ -1,173 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
|
||||
// Destructor for payload (Go handle)
|
||||
extern void goClosureDestructor(void* payload);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// inClosureCallback is set to true during closure callback execution.
|
||||
var inClosureCallback bool
|
||||
var closureCallbackMu sync.Mutex
|
||||
|
||||
// InClosureCallback returns true if we're currently executing inside a closure callback.
|
||||
func InClosureCallback() bool {
|
||||
closureCallbackMu.Lock()
|
||||
defer closureCallbackMu.Unlock()
|
||||
return inClosureCallback
|
||||
}
|
||||
|
||||
// CompiledFunc is a compiled MLX function that can be called efficiently.
|
||||
// All intermediate arrays during execution stay inside MLX - only inputs
|
||||
// and outputs cross the Go boundary.
|
||||
type CompiledFunc struct {
|
||||
closure C.mlx_closure
|
||||
compiled C.mlx_closure
|
||||
}
|
||||
|
||||
// ClosureFunc is the signature for functions that can be compiled.
|
||||
// It takes a slice of input arrays and returns a slice of output arrays.
|
||||
type ClosureFunc func(inputs []*Array) []*Array
|
||||
|
||||
// Compile compiles a Go function into an optimized MLX closure.
|
||||
// The function is traced once during compilation, then subsequent calls
|
||||
// run the optimized graph without creating Go intermediate arrays.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
// a, b := inputs[0], inputs[1]
|
||||
// c := mlx.Add(a, b)
|
||||
// d := mlx.Mul(c, c)
|
||||
// return []*mlx.Array{d}
|
||||
// })
|
||||
// defer compiled.Free()
|
||||
//
|
||||
// result := compiled.Call(x, y)[0]
|
||||
func Compile(fn ClosureFunc) *CompiledFunc {
|
||||
return CompileShapeless(fn, false)
|
||||
}
|
||||
|
||||
// CompileShapeless compiles with optional shapeless mode.
|
||||
// If shapeless=true, the function works for any input shape after tracing.
|
||||
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
|
||||
// Create a cgo.Handle to prevent the Go function from being GC'd
|
||||
handle := cgo.NewHandle(fn)
|
||||
|
||||
// Create the closure from the Go callback
|
||||
closure := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.goClosureCallback),
|
||||
unsafe.Pointer(handle),
|
||||
(*[0]byte)(C.goClosureDestructor),
|
||||
)
|
||||
|
||||
// Compile the closure
|
||||
compiled := C.mlx_closure_new()
|
||||
C.mlx_compile(&compiled, closure, C.bool(shapeless))
|
||||
|
||||
return &CompiledFunc{
|
||||
closure: closure,
|
||||
compiled: compiled,
|
||||
}
|
||||
}
|
||||
|
||||
// Call invokes the compiled function with the given inputs.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
// Pack inputs into vector
|
||||
inputVec := C.mlx_vector_array_new()
|
||||
for _, arr := range inputs {
|
||||
C.mlx_vector_array_append_value(inputVec, arr.c)
|
||||
}
|
||||
|
||||
// Apply compiled closure
|
||||
outputVec := C.mlx_vector_array_new()
|
||||
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
|
||||
C.mlx_vector_array_free(inputVec)
|
||||
|
||||
// Unpack outputs
|
||||
numOutputs := int(C.mlx_vector_array_size(outputVec))
|
||||
outputs := make([]*Array, numOutputs)
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
|
||||
outputs[i] = newArray(arr)
|
||||
}
|
||||
C.mlx_vector_array_free(outputVec)
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
// CallEval invokes the compiled function and evaluates the results.
|
||||
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
|
||||
outputs := cf.Call(inputs...)
|
||||
Eval(outputs...)
|
||||
return outputs
|
||||
}
|
||||
|
||||
// Free releases the compiled function resources.
|
||||
func (cf *CompiledFunc) Free() {
|
||||
C.mlx_closure_free(cf.compiled)
|
||||
C.mlx_closure_free(cf.closure)
|
||||
}
|
||||
|
||||
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
|
||||
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
|
||||
func borrowArray(array C.mlx_array) *Array {
|
||||
return &Array{c: array}
|
||||
}
|
||||
|
||||
//export goClosureCallback
|
||||
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
||||
// Set flag to disable AddCleanup during callback
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = true
|
||||
closureCallbackMu.Unlock()
|
||||
defer func() {
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = false
|
||||
closureCallbackMu.Unlock()
|
||||
}()
|
||||
|
||||
// Recover the Go function from the handle
|
||||
handle := cgo.Handle(payload)
|
||||
fn := handle.Value().(ClosureFunc)
|
||||
|
||||
// Convert input vector to Go slice - use borrowArray since MLX owns these
|
||||
numInputs := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, numInputs)
|
||||
for i := 0; i < numInputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, input, C.size_t(i))
|
||||
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
|
||||
}
|
||||
|
||||
// Call the Go function
|
||||
outputs := fn(inputs)
|
||||
|
||||
// Build output vector
|
||||
*res = C.mlx_vector_array_new()
|
||||
for _, arr := range outputs {
|
||||
C.mlx_vector_array_append_value(*res, arr.c)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export goClosureDestructor
|
||||
func goClosureDestructor(payload unsafe.Pointer) {
|
||||
handle := cgo.Handle(payload)
|
||||
handle.Delete()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,614 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds configuration for the text model
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
|
||||
// Computed fields
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// TextModel is the Gemma 3 text-only model
|
||||
type TextModel struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*DecoderLayer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
|
||||
NormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block
|
||||
type DecoderLayer struct {
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
Attention *Attention
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
|
||||
MLP *MLP
|
||||
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
InputNormScaled *mlx.Array `weight:"-"`
|
||||
PostAttnNormScaled *mlx.Array `weight:"-"`
|
||||
PreFFNormScaled *mlx.Array `weight:"-"`
|
||||
PostFFNormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
// Whether this layer uses sliding window attention
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
QNormScaled *mlx.Array `weight:"-"`
|
||||
KNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with GELU activation
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// LoadText loads the text-only Gemma 3 model
|
||||
func LoadText(modelPath string) (*TextModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute scale
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
// Set defaults if not specified
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
TextConfig: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize layer metadata
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for output
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
|
||||
precomputeGemmaScaledWeights(m)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
|
||||
// This avoids creating temporary arrays on every forward pass
|
||||
func precomputeGemmaScaledWeights(m *TextModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
// Eval all the precomputed weights
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
|
||||
// isLayerSliding determines if a layer uses sliding window attention
|
||||
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
|
||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
||||
if pattern <= 0 {
|
||||
return false // No sliding window
|
||||
}
|
||||
// Layer is full attention if (layerIdx + 1) % pattern == 0
|
||||
return (layerIdx+1)%pattern != 0
|
||||
}
|
||||
|
||||
// Forward runs the text model forward pass
|
||||
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
// Get embeddings and scale by sqrt(hidden_size)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// Forward runs a decoder layer
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
// Pre-attention norm (use precomputed scaled weight)
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Attention
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
|
||||
// Post-attention norm and residual
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-FFN norm
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// MLP
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
|
||||
// Post-FFN norm and residual
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
// Forward runs attention with Q/K normalization
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K normalization after reshaping (use precomputed scaled weight)
|
||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Apply RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
||||
// Repeat K/V for GQA if needed
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
|
||||
var compiledGeluApprox *mlx.CompiledFunc
|
||||
|
||||
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
|
||||
func getCompiledGeluApprox() *mlx.CompiledFunc {
|
||||
if compiledGeluApprox == nil {
|
||||
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApproxImpl(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGeluApprox
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU approximation (tanh variant)
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApproxImpl(x *mlx.Array) *mlx.Array {
|
||||
// Constants
|
||||
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
|
||||
const coeff = 0.044715
|
||||
|
||||
// x^3
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
// x + 0.044715 * x^3
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
// tanh(...)
|
||||
tanh := mlx.Tanh(scaled)
|
||||
// 1 + tanh(...)
|
||||
onePlusTanh := mlx.AddScalar(tanh, 1.0)
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
|
||||
}
|
||||
|
||||
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
|
||||
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
|
||||
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
|
||||
// Gemma uses (1 + weight) instead of weight
|
||||
scaledWeight := mlx.AddScalar(weight, 1.0)
|
||||
return mlx.RMSNorm(x, scaledWeight, eps)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *TextModel) NumLayers() int { return len(m.Layers) }
|
||||
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
|
||||
|
||||
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
|
||||
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template to a prompt
|
||||
func (m *TextModel) FormatPrompt(prompt string) string {
|
||||
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
// Use rotating cache for sliding window layers
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
// Use regular cache for global attention layers
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// Config holds config for the full multimodal model
|
||||
type Config struct {
|
||||
TextConfig TextConfig `json:"text_config"`
|
||||
VisionConfig VisionConfig `json:"vision_config"`
|
||||
|
||||
// Image token config (from config.json)
|
||||
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
|
||||
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
|
||||
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
|
||||
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
|
||||
}
|
||||
|
||||
// Model is the full Gemma 3 multimodal model
|
||||
type Model struct {
|
||||
VisionTower *VisionTower `weight:"vision_tower"`
|
||||
Projector *MultiModalProjector `weight:"multi_modal_projector"`
|
||||
TextModel *TextModel `weight:"language_model"`
|
||||
Config *Config
|
||||
tok *tokenizer.Tokenizer
|
||||
}
|
||||
|
||||
// Load loads the full multimodal Gemma 3 model
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Set defaults for text config (multimodal config often has incomplete text_config)
|
||||
// These defaults match transformers.Gemma3TextConfig defaults
|
||||
tc := &cfg.TextConfig
|
||||
if tc.HeadDim == 0 {
|
||||
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
|
||||
}
|
||||
if tc.NumAttentionHeads == 0 {
|
||||
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
|
||||
tc.NumAttentionHeads = 8
|
||||
}
|
||||
if tc.NumKeyValueHeads == 0 {
|
||||
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
|
||||
tc.NumKeyValueHeads = 4
|
||||
}
|
||||
if tc.VocabSize == 0 {
|
||||
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
|
||||
}
|
||||
if tc.RopeTheta == 0 {
|
||||
tc.RopeTheta = 1000000
|
||||
}
|
||||
if tc.RopeLocalBaseFreq == 0 {
|
||||
tc.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if tc.RMSNormEps == 0 {
|
||||
tc.RMSNormEps = 1e-6
|
||||
}
|
||||
if tc.SlidingWindowPattern == 0 {
|
||||
tc.SlidingWindowPattern = 6
|
||||
}
|
||||
if tc.MaxPositionEmbeddings == 0 {
|
||||
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
|
||||
}
|
||||
|
||||
// Compute text model scale
|
||||
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
|
||||
|
||||
// Set defaults for image token config
|
||||
if cfg.BOITokenIndex == 0 {
|
||||
cfg.BOITokenIndex = 255999 // <start_of_image>
|
||||
}
|
||||
if cfg.EOITokenIndex == 0 {
|
||||
cfg.EOITokenIndex = 256000 // <end_of_image>
|
||||
}
|
||||
if cfg.ImageTokenIndex == 0 {
|
||||
cfg.ImageTokenIndex = 262144 // <image_soft_token>
|
||||
}
|
||||
if cfg.MMTokensPerImage == 0 {
|
||||
cfg.MMTokensPerImage = 256
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
VisionTower: &VisionTower{
|
||||
Embeddings: &VisionEmbeddings{},
|
||||
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
|
||||
Config: &cfg.VisionConfig,
|
||||
},
|
||||
Projector: &MultiModalProjector{},
|
||||
TextModel: &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
|
||||
TextConfig: &cfg.TextConfig,
|
||||
},
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize text layer metadata
|
||||
for i := range m.TextModel.Layers {
|
||||
m.TextModel.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize vision encoder layers
|
||||
for i := range m.VisionTower.Encoder {
|
||||
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for text output
|
||||
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
|
||||
m.TextModel.tok = tok
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
||||
precomputeGemmaScaledWeights(m.TextModel)
|
||||
|
||||
// Precompute projector's scaled weight
|
||||
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
|
||||
mlx.Eval(m.Projector.SoftEmbNormScaled)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the text-only forward pass
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
return m.TextModel.Forward(tokens, caches)
|
||||
}
|
||||
|
||||
// ForwardWithImage runs the multimodal forward pass
|
||||
// tokens: [B, L] input token IDs (with image placeholder tokens)
|
||||
// image: [B, H, W, C] preprocessed image tensor
|
||||
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
cfg := m.Config.TextConfig
|
||||
|
||||
// Find image token position FIRST before any eval that might free tokens
|
||||
imageStartPos := int32(-1)
|
||||
if image != nil && B == 1 {
|
||||
tokenData := tokens.DataInt32() // This evals tokens
|
||||
for i, t := range tokenData {
|
||||
if t == m.Config.ImageTokenIndex {
|
||||
imageStartPos = int32(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get text embeddings and scale
|
||||
h := m.TextModel.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
|
||||
|
||||
// Process image if provided
|
||||
if image != nil && imageStartPos >= 0 {
|
||||
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
|
||||
visionFeatures := m.VisionTower.Forward(image)
|
||||
|
||||
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
|
||||
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
|
||||
|
||||
// Eval h and imageEmbeds together so neither gets freed
|
||||
mlx.Eval(h, imageEmbeds)
|
||||
|
||||
// Cast imageEmbeds to match text embeddings dtype (bf16)
|
||||
if imageEmbeds.Dtype() != h.Dtype() {
|
||||
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
|
||||
mlx.Eval(imageEmbeds)
|
||||
}
|
||||
|
||||
// Insert image embeddings at the known position
|
||||
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
|
||||
}
|
||||
|
||||
// Run through text model layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
|
||||
// at a known position (to avoid re-scanning tokens after eval)
|
||||
// textEmbeds: [B, L, hidden_size] text embeddings
|
||||
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
|
||||
// startPos: starting position of image tokens in the sequence
|
||||
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
|
||||
numImageTokens := imageEmbeds.Shape()[1]
|
||||
L := textEmbeds.Shape()[1]
|
||||
|
||||
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
|
||||
afterStart := startPos + numImageTokens
|
||||
|
||||
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
|
||||
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
|
||||
|
||||
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
|
||||
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
|
||||
|
||||
// Concatenate: before + imageEmbeds + after along axis 1
|
||||
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
|
||||
}
|
||||
|
||||
// Interface methods for Model
|
||||
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
|
||||
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
|
||||
|
||||
// FormatPrompt applies the Gemma 3 multimodal chat template
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
|
||||
func (m *Model) FormatPromptWithImage(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
|
||||
// Input tokens containing boi_token (255999) are expanded to:
|
||||
// boi_token + 256 * image_token + eoi_token
|
||||
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
|
||||
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
|
||||
|
||||
for _, t := range tokens {
|
||||
if t == m.Config.BOITokenIndex {
|
||||
// Expand: boi + 256 * image_token + eoi
|
||||
result = append(result, m.Config.BOITokenIndex)
|
||||
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
|
||||
result = append(result, m.Config.ImageTokenIndex)
|
||||
}
|
||||
result = append(result, m.Config.EOITokenIndex)
|
||||
} else {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// ProcessImage loads and preprocesses an image for the vision tower
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
|
||||
return ProcessImageData(img, imageSize)
|
||||
}
|
||||
|
||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
// Resize to target size using bilinear interpolation
|
||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert to float32 array [H, W, C] and normalize
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
||||
data := make([]float32, imageSize*imageSize*3)
|
||||
idx := 0
|
||||
for y := int32(0); y < imageSize; y++ {
|
||||
for x := int32(0); x < imageSize; x++ {
|
||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||
// RGBA returns 16-bit values, convert to 8-bit
|
||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
|
||||
// Create MLX array [1, H, W, C] for NHWC layout
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
||||
return arr, nil
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// MultiModalProjector projects vision features to text embedding space
|
||||
type MultiModalProjector struct {
|
||||
// mm_input_projection_weight: [vision_hidden, text_hidden]
|
||||
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
|
||||
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
SoftEmbNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// Forward projects vision features to text space
|
||||
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
||||
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
||||
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
||||
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
||||
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
||||
B := visionFeatures.Shape()[0]
|
||||
visionHidden := visionFeatures.Shape()[2]
|
||||
|
||||
// Reshape to [B, 64, 64, hidden]
|
||||
gridSize := int32(64) // sqrt(4096)
|
||||
pooledSize := int32(16) // 64/4
|
||||
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
||||
|
||||
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
||||
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
||||
|
||||
// Average over pooling dimensions (axes 2 and 4)
|
||||
h = mlx.Mean(h, 4, false)
|
||||
h = mlx.Mean(h, 2, false)
|
||||
|
||||
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
||||
numTokens := pooledSize * pooledSize
|
||||
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
||||
|
||||
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
||||
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
||||
|
||||
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
||||
return mlx.Linear(h, p.InputProjection)
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// VisionConfig holds configuration for the SigLIP vision tower
|
||||
type VisionConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
ImageSize int32 `json:"image_size"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
}
|
||||
|
||||
// VisionTower is the SigLIP vision encoder
|
||||
type VisionTower struct {
|
||||
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
|
||||
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
|
||||
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
|
||||
Config *VisionConfig
|
||||
}
|
||||
|
||||
// VisionEmbeddings handles patch and position embeddings
|
||||
type VisionEmbeddings struct {
|
||||
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
|
||||
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
|
||||
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
|
||||
PosEmbed *nn.Embedding `weight:"position_embedding"`
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is a single transformer encoder layer
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
|
||||
Attention *VisionAttention `weight:"self_attn"`
|
||||
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
|
||||
MLP *VisionMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// VisionAttention implements multi-head self-attention
|
||||
type VisionAttention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OutProj *nn.Linear `weight:"out_proj"`
|
||||
}
|
||||
|
||||
// VisionMLP is the feed-forward network
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `weight:"fc1"`
|
||||
FC2 *nn.Linear `weight:"fc2"`
|
||||
}
|
||||
|
||||
// Forward runs the vision tower on preprocessed images
|
||||
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
|
||||
// Output: [B, num_patches, hidden_size]
|
||||
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
|
||||
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
|
||||
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
|
||||
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
|
||||
|
||||
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
|
||||
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
|
||||
h = mlx.Add(h, bias)
|
||||
|
||||
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
|
||||
B := h.Shape()[0]
|
||||
gridH, gridW := h.Shape()[1], h.Shape()[2]
|
||||
hidden := h.Shape()[3]
|
||||
numPatches := gridH * gridW
|
||||
h = mlx.Reshape(h, B, numPatches, hidden)
|
||||
|
||||
// Add position embeddings
|
||||
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
|
||||
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
|
||||
h = mlx.Add(h, posEmbed)
|
||||
|
||||
// Encoder layers
|
||||
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
|
||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||
for _, layer := range v.Encoder {
|
||||
h = layer.Forward(h, v.Config, scale)
|
||||
}
|
||||
|
||||
// Final layer norm
|
||||
h = v.PostLayerNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward runs a vision encoder layer
|
||||
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
// Pre-norm attention
|
||||
h := l.LayerNorm1.Forward(x)
|
||||
h = l.Attention.Forward(h, cfg, scale)
|
||||
x = mlx.Add(x, h)
|
||||
|
||||
// Pre-norm MLP
|
||||
h = l.LayerNorm2.Forward(x)
|
||||
h = l.MLP.Forward(h)
|
||||
return mlx.Add(x, h)
|
||||
}
|
||||
|
||||
// Forward runs multi-head self-attention
|
||||
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
B, L := x.Shape()[0], x.Shape()[1]
|
||||
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention (no causal mask for vision)
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
|
||||
|
||||
return a.OutProj.Forward(out)
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU activation
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := mlx.GELU(m.FC1.Forward(x))
|
||||
return m.FC2.Forward(h)
|
||||
}
|
||||
@@ -1,487 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package gpt_oss
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds YaRN or other RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
NumLocalExperts int32 `json:"num_local_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
SwiGLULimit float32 `json:"swiglu_limit"`
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
|
||||
YarnFreqs *mlx.Array // computed
|
||||
YarnMscale float32
|
||||
}
|
||||
|
||||
// swiGLU applies the GPT-OSS custom SwiGLU activation.
|
||||
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
|
||||
// with clipping: gate to [None, limit], up to [-limit, limit]
|
||||
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
|
||||
// Clip gate to [None, limit]
|
||||
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
|
||||
|
||||
// Clip up to [-limit, limit]
|
||||
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
|
||||
|
||||
// glu_scaled = alpha * gate_clipped
|
||||
gluScaled := mlx.MulScalar(gateClipped, alpha)
|
||||
|
||||
// sig = sigmoid(glu_scaled)
|
||||
sig := mlx.Sigmoid(gluScaled)
|
||||
|
||||
// out_glu = gate_clipped * sig
|
||||
outGlu := mlx.Mul(gateClipped, sig)
|
||||
|
||||
// result = out_glu * (up_clipped + 1)
|
||||
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
|
||||
}
|
||||
|
||||
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
const alpha float32 = 1.702
|
||||
const limit float32 = 7.0
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
|
||||
}, true) // shapeless=true so it works for any input size
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
|
||||
// Based on mlx-lm's YarnRoPE implementation
|
||||
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
|
||||
// yarn_find_correction_dim
|
||||
yarnFindCorrectionDim := func(numRotations float64) float64 {
|
||||
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
|
||||
}
|
||||
|
||||
// yarn_find_correction_range
|
||||
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
|
||||
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
|
||||
if low < 0 {
|
||||
low = 0
|
||||
}
|
||||
if high > int(dims)-1 {
|
||||
high = int(dims) - 1
|
||||
}
|
||||
|
||||
// yarn_get_mscale
|
||||
yarnGetMscale := func(scale, mscale float64) float64 {
|
||||
if scale <= 1 {
|
||||
return 1.0
|
||||
}
|
||||
return 0.1*mscale*math.Log(scale) + 1.0
|
||||
}
|
||||
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
|
||||
|
||||
// Compute frequencies
|
||||
// freq_extra = base ** (arange(0, dims, 2) / dims)
|
||||
// freq_inter = scaling_factor * freq_extra
|
||||
halfDims := dims / 2
|
||||
freqData := make([]float32, halfDims)
|
||||
for i := int32(0); i < halfDims; i++ {
|
||||
exp := float64(2*i) / float64(dims)
|
||||
freqExtra := math.Pow(float64(base), exp)
|
||||
freqInter := float64(scalingFactor) * freqExtra
|
||||
|
||||
// linear ramp mask
|
||||
var freqMask float64
|
||||
if low == high {
|
||||
freqMask = 0.0
|
||||
} else {
|
||||
t := (float64(i) - float64(low)) / float64(high-low)
|
||||
if t < 0 {
|
||||
t = 0
|
||||
}
|
||||
if t > 1 {
|
||||
t = 1
|
||||
}
|
||||
freqMask = 1.0 - t
|
||||
}
|
||||
|
||||
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
|
||||
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
|
||||
}
|
||||
|
||||
return mlx.NewArray(freqData, []int32{halfDims}), mscale
|
||||
}
|
||||
|
||||
// initYarn initializes YaRN RoPE if configured
|
||||
func (a *Attention) initYarn(cfg *Config) {
|
||||
a.YarnMscale = 1.0
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
|
||||
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
|
||||
cfg.HeadDim,
|
||||
cfg.RopeTheta,
|
||||
cfg.RopeScaling.Factor,
|
||||
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
|
||||
cfg.RopeScaling.BetaFast,
|
||||
cfg.RopeScaling.BetaSlow,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
if a.YarnFreqs != nil {
|
||||
if a.YarnMscale != 1.0 {
|
||||
q = mlx.MulScalar(q, a.YarnMscale)
|
||||
}
|
||||
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
} else {
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// CreateSlidingWindowMask creates a causal mask with sliding window
|
||||
// Mirrors mlx-lm's create_causal_mask with window_size
|
||||
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
|
||||
// Build mask aligned to actual cache length (may be rotated)
|
||||
// rinds covers existing keys: [keyStart, keyStart+keyLen)
|
||||
// linds covers new queries: [queryStart, queryStart+seqLen)
|
||||
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
|
||||
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
|
||||
|
||||
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
|
||||
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
|
||||
|
||||
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
|
||||
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
|
||||
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
|
||||
|
||||
return mlx.LogicalAnd(causalMask, windowMask)
|
||||
}
|
||||
|
||||
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
|
||||
type MoE struct {
|
||||
Router *nn.Linear `weight:"mlp.router"`
|
||||
TopK int32
|
||||
HiddenSize int32
|
||||
GroupSize int
|
||||
Bits int
|
||||
// Expert weights (loaded manually via sanitizeExpertWeights)
|
||||
GateBlocks, GateScales, GateBias *mlx.Array
|
||||
UpBlocks, UpScales, UpBias *mlx.Array
|
||||
DownBlocks, DownScales, DownBias *mlx.Array
|
||||
}
|
||||
|
||||
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
|
||||
logits := moe.Router.Forward(x)
|
||||
neg := mlx.Neg(logits)
|
||||
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
|
||||
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
|
||||
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
|
||||
weights := mlx.Softmax(topKVal, -1)
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
|
||||
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
|
||||
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
sorted := false
|
||||
n := B * L * moe.TopK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
sorted = true
|
||||
}
|
||||
|
||||
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
|
||||
if moe.GateBias != nil {
|
||||
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
|
||||
}
|
||||
if moe.UpBias != nil {
|
||||
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
hidden := getCompiledSwiGLU().Call(gate, up)[0]
|
||||
|
||||
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
if moe.DownBias != nil {
|
||||
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
|
||||
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
|
||||
}
|
||||
|
||||
type Block struct {
|
||||
Attention *Attention
|
||||
MLP *MoE
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
LayerType string // "sliding_attention" or "full_attention"
|
||||
}
|
||||
|
||||
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
|
||||
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead *nn.Linear `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
func (m *Model) NewCache(int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
x := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
// Find representative cache indices for sliding window attention
|
||||
var swaIdx int = -1
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
swaIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Create masks once at model level
|
||||
var fullMask, swaMask *mlx.Array
|
||||
var fullMaskMode, swaMaskMode string
|
||||
|
||||
if L > 1 {
|
||||
fullMaskMode = "causal"
|
||||
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
|
||||
c := caches[swaIdx]
|
||||
offset := c.Offset()
|
||||
windowSize := int(m.SlidingWindow)
|
||||
cacheLen := min(int(L), windowSize)
|
||||
if offset > 0 {
|
||||
cacheLen = min(c.Len()+int(L), windowSize)
|
||||
}
|
||||
if int(L) > windowSize {
|
||||
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
mask, maskMode := fullMask, fullMaskMode
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
mask, maskMode = swaMask, swaMaskMode
|
||||
}
|
||||
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
|
||||
}
|
||||
|
||||
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
|
||||
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
|
||||
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
|
||||
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
|
||||
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
|
||||
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
|
||||
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
|
||||
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
|
||||
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
|
||||
|
||||
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
|
||||
|
||||
if gateUpBlocks != nil {
|
||||
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
s := gub.Shape()
|
||||
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpScales != nil {
|
||||
s := gateUpScales.Shape()
|
||||
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpBias != nil {
|
||||
s := gateUpBias.Shape()
|
||||
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
}
|
||||
if downBlocks != nil {
|
||||
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
}
|
||||
return moe
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load simple weights via struct tags
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers with custom MoE handling
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
layer := &Block{}
|
||||
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Initialize attention YaRN
|
||||
layer.Attention.initYarn(&cfg)
|
||||
|
||||
// Load MoE with weight sanitization
|
||||
moe := sanitizeExpertWeights(weights, prefix)
|
||||
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
|
||||
moe.TopK = cfg.NumExpertsPerTok
|
||||
moe.HiddenSize = cfg.HiddenSize
|
||||
layer.MLP = moe
|
||||
|
||||
// Set layer type
|
||||
layer.LayerType = "full_attention"
|
||||
if int(i) < len(cfg.LayerTypes) {
|
||||
layer.LayerType = cfg.LayerTypes[i]
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured data,
|
||||
// this reduces peak memory by freeing mmap during materialization
|
||||
weights.ReleaseAll()
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int32 {
|
||||
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
|
||||
return m.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
}
|
||||
return 131072
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
HeadDim int32 `json:"-"`
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Layer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"lm_head,optional"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.Config)
|
||||
}
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
||||
k, v = c.Update(k, v, int(L))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,350 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// True CFG: run twice and combine with norm rescaling
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.5
|
||||
MaxShift float32 `json:"max_shift"` // 0.9
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift based on image sequence length
|
||||
// This matches Python's calculate_shift function
|
||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
||||
b := baseShift - m*float32(baseSeqLen)
|
||||
mu := float32(imageSeqLen)*m + b
|
||||
return mu
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
||||
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
|
||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
pH := latentH / patchSize
|
||||
pW := latentW / patchSize
|
||||
inChannels := int32(64) // 16 * patch_size^2
|
||||
return []int32{batchSize, pH * pW, inChannels}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
func abs32(x float32) float32 {
|
||||
return float32(math.Abs(float64(x)))
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
||||
type TinyTextEncoderConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,868 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Qwen-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
||||
NHeads int32 `json:"num_attention_heads"` // 24
|
||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
||||
NLayers int32 `json:"num_layers"` // 60
|
||||
InChannels int32 `json:"in_channels"` // 64
|
||||
OutChannels int32 `json:"out_channels"` // 16
|
||||
PatchSize int32 `json:"patch_size"` // 2
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
||||
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
||||
halfDim := dim / 2
|
||||
freqs := make([]float64, halfDim)
|
||||
for i := int32(0); i < halfDim; i++ {
|
||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
||||
}
|
||||
return freqs
|
||||
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
||||
table := make([][]float32, maxIdx)
|
||||
for idx := int32(0); idx < maxIdx; idx++ {
|
||||
var pos float64
|
||||
if negative {
|
||||
pos = float64(-maxIdx + int32(idx))
|
||||
} else {
|
||||
pos = float64(idx)
|
||||
}
|
||||
|
||||
row := make([]float32, len(baseFreqs)*2)
|
||||
for i, f := range baseFreqs {
|
||||
angle := pos * f
|
||||
row[i*2] = float32(math.Cos(angle))
|
||||
row[i*2+1] = float32(math.Sin(angle))
|
||||
}
|
||||
table[idx] = row
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func max(a, b int32) int32 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestTransformerConfig tests configuration invariants.
|
||||
func TestTransformerConfig(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Property: hidden_dim = n_heads * head_dim
|
||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: axes_dims_rope sums to head_dim
|
||||
var ropeSum int32
|
||||
for _, d := range cfg.AxesDimsRope {
|
||||
ropeSum += d
|
||||
}
|
||||
if ropeSum != cfg.HeadDim {
|
||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: in_channels = out_channels * patch_size^2
|
||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
||||
if cfg.InChannels != expectedIn {
|
||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
||||
func TestTransformerRoPE(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Test with small image dimensions
|
||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
||||
txtLen := int32(5)
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Verify shapes: [seq_len, head_dim]
|
||||
imgSeqLen := imgH * imgW
|
||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
||||
}
|
||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
||||
}
|
||||
|
||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
||||
}
|
||||
|
||||
// Verify values are finite
|
||||
imgData := ropeCache.ImgFreqs.Data()
|
||||
for i := 0; i < min(100, len(imgData)); i++ {
|
||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestTransformerForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
transformer := &Transformer{}
|
||||
if err := transformer.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load transformer: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(transformer)...)
|
||||
cfg := transformer.Config
|
||||
|
||||
// Small test inputs
|
||||
batchSize := int32(1)
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
imgSeqLen := imgH * imgW
|
||||
txtSeqLen := int32(5)
|
||||
|
||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
||||
|
||||
// Forward pass
|
||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
||||
gotShape := out.Shape()
|
||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,854 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
// Input is channels-last: [B, T, H, W, C]
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
// Works with channels-last [B, T, H, W, C] format
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
||||
var h *mlx.Array
|
||||
|
||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection (shortcut handles its own pools if present)
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
||||
|
||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// out: [B*T, 1, H*W, C]
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block with staged memory management
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// ResBlocks handle their own pools
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Upsampler handles its own pools
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block of decoder
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Each component handles its own pools; we just free inputs
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image VAE decoder...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading post_quant_conv... ")
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.PostQuantConv = postQuantConv
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = convIn
|
||||
fmt.Println("✓")
|
||||
|
||||
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
|
||||
fmt.Print(" Loading mid_block... ")
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MidBlock = midBlock
|
||||
fmt.Println("✓")
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
fmt.Print(" Loading up_blocks... ")
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
m.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.UpBlocks[i] = upBlock
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NormOut = normOut
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = convOut
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
||||
m := &VAEDecoder{}
|
||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] normalized latents
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Stage 1a: Denormalize and transpose
|
||||
{
|
||||
z = vae.Denormalize(z)
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// Stage 1b: PostQuantConv (handles its own pools)
|
||||
x = vae.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// Stage 1c: ConvIn (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 2: Mid block (handles its own pools)
|
||||
x = vae.MidBlock.Forward(x)
|
||||
|
||||
// Stage 3: Up blocks (each handles its own pools)
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// Stage 4a: NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = vae.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 4b: ConvOut (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 4c: Post-processing
|
||||
{
|
||||
prev := x
|
||||
// Clamp to [-1, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
// Convert back from channels-last to channels-first
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Denormalize reverses the normalization applied during encoding
|
||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
mean = mlx.ToBFloat16(mean)
|
||||
std = mlx.ToBFloat16(std)
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
||||
}
|
||||
|
||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
||||
|
||||
wShape := weight.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = weight
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
out := mlx.Linear(x, w)
|
||||
outC := w.Dim(1)
|
||||
out = mlx.Reshape(out, B, H, W, outC)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
||||
x = pad2D(x, 1, 1, 1, 1)
|
||||
return conv2D(x, weight, 1, 1)
|
||||
}
|
||||
|
||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
||||
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := w.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
||||
wFlat := mlx.Reshape(w, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 2)
|
||||
x = mlx.Take(x, colIdx, 3)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// Create repeat indices for rows
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
// Create repeat indices for columns
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
// Take along H (axis 1) then W (axis 2)
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
||||
// stride=1, padding=0 (we already padded manually)
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestVAEConfig tests configuration invariants.
|
||||
func TestVAEConfig(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Property: latents_mean and latents_std have z_dim elements
|
||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
||||
}
|
||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
||||
}
|
||||
|
||||
// Property: dim_mult defines 4 stages
|
||||
if len(cfg.DimMult) != 4 {
|
||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
||||
}
|
||||
|
||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
||||
if len(cfg.TemperalDownsample) != 3 {
|
||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
||||
func TestVAELatentsNormalization(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Verify latents_std values are all positive
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std <= 0 {
|
||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are in reasonable range (from actual model)
|
||||
for i, mean := range cfg.LatentsMean {
|
||||
if math.Abs(float64(mean)) > 5 {
|
||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
||||
}
|
||||
}
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std > 10 {
|
||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestVAEDecoderForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
vae := &VAEDecoder{}
|
||||
if err := vae.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(vae)...)
|
||||
|
||||
// Small test input: [B, C, T, H, W]
|
||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
||||
batchSize := int32(1)
|
||||
channels := int32(16)
|
||||
frames := int32(1)
|
||||
latentH := int32(4)
|
||||
latentW := int32(4)
|
||||
|
||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
||||
|
||||
// Decode
|
||||
out := vae.Decode(latents)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
||||
outShape := out.Shape()
|
||||
if outShape[0] != batchSize {
|
||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
||||
}
|
||||
if outShape[1] != 3 {
|
||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
||||
}
|
||||
if outShape[2] != frames {
|
||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
||||
}
|
||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
||||
expectedW := latentW * 16
|
||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
||||
outShape[3], outShape[4], expectedH, expectedW)
|
||||
}
|
||||
|
||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,682 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape()
|
||||
|
||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
||||
if len(shape) == 4 {
|
||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
||||
return c.forward2D(x)
|
||||
}
|
||||
|
||||
// 3D conv: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
||||
kernelH := wShape[2]
|
||||
kernelW := wShape[3]
|
||||
outC := wShape[0]
|
||||
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad spatially
|
||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
||||
|
||||
// Apply 2D conv
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = mlx.Conv2d(x, weight, 1, 0)
|
||||
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Get output spatial dims
|
||||
outH := H
|
||||
outW := W
|
||||
|
||||
// Reshape back to [B, T, H, W, C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
qkv := mlx.Linear(x, w)
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0)
|
||||
out = mlx.Linear(out, p)
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
|
||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
||||
// strideT, strideH, strideW are the strides for each dimension
|
||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
// I := wShape[1]
|
||||
kT := wShape[2]
|
||||
kH := wShape[3]
|
||||
kW := wShape[4]
|
||||
|
||||
// For temporal: if T < kT, we need to repeat frames temporally
|
||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
||||
if T < kT {
|
||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
||||
T = kT
|
||||
}
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
||||
|
||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// extractPatches3DStrided extracts 3D patches with given strides
|
||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
numPatches := outT * outH * outW
|
||||
patches := make([]*mlx.Array, numPatches)
|
||||
idx := 0
|
||||
for t := int32(0); t < outT; t++ {
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startT := t * strideT
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
// Extract patch: [B, kT, kH, kW, C]
|
||||
patch := mlx.Slice(x,
|
||||
[]int32{0, startT, startH, startW, 0},
|
||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
||||
// Flatten to [B, C*T*H*W]
|
||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
||||
}
|
||||
|
||||
// extractPatches2DStrided extracts patches with given stride
|
||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * stride
|
||||
startW := j * stride
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
@@ -1,475 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// loadImageFile loads an image from disk
|
||||
func loadImageFile(path string) (image.Image, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
|
||||
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
|
||||
pixels := make([]float32, width*height*3)
|
||||
idx := 0
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
pixels[idx] = float32(r) / 65535.0
|
||||
pixels[idx+1] = float32(g) / 65535.0
|
||||
pixels[idx+2] = float32(b) / 65535.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
return pixels
|
||||
}
|
||||
|
||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
||||
}
|
||||
|
||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
// Add batch dimension [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0)
|
||||
// Convert to bf16
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// clampFloat clamps a value to [0, 255] and returns uint8
|
||||
func clampFloat(v, weightSum float64) uint8 {
|
||||
v /= weightSum
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
if v > 255 {
|
||||
v = 255
|
||||
}
|
||||
return uint8(math.Round(v))
|
||||
}
|
||||
|
||||
// ImageDims holds dimensions for a preprocessed image
|
||||
type ImageDims struct {
|
||||
// Original image dimensions
|
||||
OrigW, OrigH int32
|
||||
// Condition image dimensions (for vision encoder)
|
||||
CondW, CondH int32
|
||||
// VAE image dimensions
|
||||
VaeW, VaeH int32
|
||||
// Latent dimensions (VAE dims / vae_scale_factor)
|
||||
LatentW, LatentH int32
|
||||
// Patch dimensions (latent dims / patch_size)
|
||||
PatchW, PatchH int32
|
||||
}
|
||||
|
||||
// ProcessorConfig holds image processor configuration
|
||||
type ProcessorConfig struct {
|
||||
// Condition image size (target pixel area for vision encoder input)
|
||||
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
|
||||
// Pipeline resizes image to this area before passing to encode_prompt
|
||||
ConditionImageSize int32
|
||||
|
||||
// VAE image size (target pixel area)
|
||||
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
|
||||
VAEImageSize int32
|
||||
|
||||
// Image normalization (ImageNet stats for vision encoder)
|
||||
ImageMean []float32
|
||||
ImageStd []float32
|
||||
}
|
||||
|
||||
// defaultProcessorConfig returns default processor config
|
||||
func defaultProcessorConfig() *ProcessorConfig {
|
||||
return &ProcessorConfig{
|
||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
||||
type Processor struct {
|
||||
Config *ProcessorConfig
|
||||
}
|
||||
|
||||
// Load loads the processor config
|
||||
func (p *Processor) Load(path string) error {
|
||||
p.Config = defaultProcessorConfig()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := bounds.Dx()
|
||||
origH := bounds.Dy()
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize
|
||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
|
||||
return condImage, vaeImage, nil
|
||||
}
|
||||
|
||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
||||
// Used by edit_plus pipeline for multi-image input
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImage converts image to tensor for vision encoder
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
||||
resized := resizeImageBicubic(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
if normalize {
|
||||
arr = p.normalizeImageNet(arr)
|
||||
}
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
|
||||
// Scale to [-1, 1]: arr * 2 - 1
|
||||
arr = mlx.MulScalar(arr, 2.0)
|
||||
arr = mlx.AddScalar(arr, -1.0)
|
||||
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
|
||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
||||
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// smartResize implements Python Qwen2VL processor's smart_resize logic
|
||||
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
|
||||
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
|
||||
// Round to factor
|
||||
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
|
||||
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
|
||||
|
||||
// Ensure minimum factor size
|
||||
if hBar < factor {
|
||||
hBar = factor
|
||||
}
|
||||
if wBar < factor {
|
||||
wBar = factor
|
||||
}
|
||||
|
||||
// Check pixel constraints
|
||||
total := hBar * wBar
|
||||
if total > maxPixels {
|
||||
// Scale down
|
||||
beta := math.Sqrt(float64(maxPixels) / float64(total))
|
||||
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
|
||||
} else if total < minPixels {
|
||||
// Scale up
|
||||
beta := math.Sqrt(float64(minPixels) / float64(total))
|
||||
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
// calculateDimensions calculates width and height for a target area while maintaining ratio
|
||||
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
|
||||
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
|
||||
width := math.Sqrt(float64(targetArea) * ratio)
|
||||
height := width / ratio
|
||||
|
||||
m := float64(multiple)
|
||||
width = math.Round(width/m) * m
|
||||
height = math.Round(height/m) * m
|
||||
|
||||
// Ensure minimum dimensions
|
||||
if width < m {
|
||||
width = m
|
||||
}
|
||||
if height < m {
|
||||
height = m
|
||||
}
|
||||
|
||||
return int32(width), int32(height)
|
||||
}
|
||||
|
||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
||||
lanczos3 := &draw.Kernel{
|
||||
Support: 3.0,
|
||||
At: func(t float64) float64 {
|
||||
if t == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if t < 0 {
|
||||
t = -t
|
||||
}
|
||||
if t >= 3.0 {
|
||||
return 0.0
|
||||
}
|
||||
// sinc(t) * sinc(t/3)
|
||||
piT := math.Pi * t
|
||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
||||
},
|
||||
}
|
||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
// Convert to RGBA if needed
|
||||
var src *image.RGBA
|
||||
if rgba, ok := img.(*image.RGBA); ok {
|
||||
src = rgba
|
||||
} else {
|
||||
src = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
src.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
||||
cubic := func(x float64) float64 {
|
||||
if x < 0 {
|
||||
x = -x
|
||||
}
|
||||
if x < 1 {
|
||||
return 1.5*x*x*x - 2.5*x*x + 1
|
||||
}
|
||||
if x < 2 {
|
||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Horizontal pass: srcW -> width, keep srcH rows
|
||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
||||
for y := 0; y < srcH; y++ {
|
||||
for dstX := 0; dstX < width; dstX++ {
|
||||
// PIL coordinate mapping: center-to-center
|
||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
||||
baseX := int(math.Floor(srcXf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for i := -1; i <= 2; i++ {
|
||||
sx := baseX + i
|
||||
if sx < 0 {
|
||||
sx = 0
|
||||
}
|
||||
if sx >= srcW {
|
||||
sx = srcW - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
||||
c := src.RGBAAt(sx, y)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
temp.SetRGBA(dstX, y, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Vertical pass: srcH -> height
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for x := 0; x < width; x++ {
|
||||
for dstY := 0; dstY < height; dstY++ {
|
||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
||||
baseY := int(math.Floor(srcYf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for j := -1; j <= 2; j++ {
|
||||
sy := baseY + j
|
||||
if sy < 0 {
|
||||
sy = 0
|
||||
}
|
||||
if sy >= srcH {
|
||||
sy = srcH - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
||||
c := temp.RGBAAt(x, sy)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
dst.SetRGBA(x, dstY, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
||||
const vaeScaleFactor int32 = 8
|
||||
const patchSize int32 = 2
|
||||
|
||||
condImages := make([]*mlx.Array, len(imagePaths))
|
||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
||||
dims := make([]ImageDims, len(imagePaths))
|
||||
|
||||
for i, imagePath := range imagePaths {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := int32(bounds.Dx())
|
||||
origH := int32(bounds.Dy())
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Calculate derived dimensions
|
||||
latentW := vaeW / vaeScaleFactor
|
||||
latentH := vaeH / vaeScaleFactor
|
||||
patchW := latentW / patchSize
|
||||
patchH := latentH / patchSize
|
||||
|
||||
dims[i] = ImageDims{
|
||||
OrigW: origW,
|
||||
OrigH: origH,
|
||||
CondW: condW,
|
||||
CondH: condH,
|
||||
VaeW: vaeW,
|
||||
VaeH: vaeH,
|
||||
LatentW: latentW,
|
||||
LatentH: latentH,
|
||||
PatchW: patchW,
|
||||
PatchH: patchH,
|
||||
}
|
||||
|
||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
}
|
||||
|
||||
return condImages, vaeImages, dims, nil
|
||||
}
|
||||
@@ -1,610 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
||||
// It reuses components from qwen_image where possible.
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Processor *Processor // Image processor for vision encoder
|
||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
||||
VAE *VAE // Combined encoder + decoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image-Edit model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer from processor directory
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
processorPath := filepath.Join(modelPath, "processor")
|
||||
tok, err := tokenizer.Load(processorPath)
|
||||
if err != nil {
|
||||
// Fallback to tokenizer directory
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err = tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load processor (image preprocessing config)
|
||||
fmt.Print(" Loading processor... ")
|
||||
m.Processor = &Processor{}
|
||||
if err := m.Processor.Load(processorPath); err != nil {
|
||||
return fmt.Errorf("processor: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer (reuse qwen_image)
|
||||
m.Transformer = &qwen_image.Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE (encoder + decoder)
|
||||
m.VAE = &VAE{}
|
||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edit edits an image based on a text prompt.
|
||||
// inputImagePath: path to input image
|
||||
// prompt: text description of desired edit
|
||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// EditFromConfig edits images using the unified config struct.
|
||||
// Accepts one or more input images.
|
||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if len(inputImagePaths) == 0 {
|
||||
return nil, fmt.Errorf("no input images provided")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := m.edit(inputImagePaths, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EditImage implements model.ImageEditModel interface.
|
||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// EditMultiImage edits using multiple source images.
|
||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
return m.EditFromConfig(inputImagePaths, cfg)
|
||||
}
|
||||
|
||||
// edit is the internal editing pipeline that handles one or more images.
|
||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
|
||||
// Load and preprocess all input images
|
||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
||||
}
|
||||
for _, img := range condImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
for _, img := range vaeImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
mlx.Eval(append(condImages, vaeImages...)...)
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
vaeScaleFactor := int32(8)
|
||||
|
||||
// Output dimensions - if not specified, use first input image dimensions
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = inputDims[0].VaeW
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = inputDims[0].VaeH
|
||||
}
|
||||
|
||||
// Output (noise) latent dimensions
|
||||
outLatentH := cfg.Height / vaeScaleFactor
|
||||
outLatentW := cfg.Width / vaeScaleFactor
|
||||
outPH := outLatentH / tcfg.PatchSize
|
||||
outPW := outLatentW / tcfg.PatchSize
|
||||
noiseSeqLen := outPH * outPW
|
||||
imgSeqLen := noiseSeqLen
|
||||
|
||||
// Encode prompt with all images for conditioning
|
||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
|
||||
var negEmb *mlx.Array
|
||||
if useCFG {
|
||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(negEmb)
|
||||
mlx.Eval(negEmb)
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
for i, vaeImage := range vaeImages {
|
||||
imageLatents := m.VAE.Encode(vaeImage)
|
||||
imageLatents = m.VAE.Normalize(imageLatents)
|
||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
||||
mlx.Keep(packed)
|
||||
mlx.Eval(packed)
|
||||
allImageLatentsPacked[i] = packed
|
||||
}
|
||||
|
||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
||||
mlx.Keep(imageLatentsPacked)
|
||||
mlx.Eval(imageLatentsPacked)
|
||||
|
||||
// Scheduler
|
||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
||||
|
||||
// Init noise latents in packed format
|
||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// RoPE cache
|
||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
mlx.Eval(timestep)
|
||||
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
||||
}
|
||||
|
||||
// Free denoising temporaries
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
// Decode latents
|
||||
decoded := m.decodeAndPostprocess(latents)
|
||||
latents.Free()
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
||||
// This prevents CFG from inflating magnitude too much.
|
||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
||||
// Upcast to float32 for precision
|
||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
||||
|
||||
// CFG: pred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posF32, negF32)
|
||||
scaledDiff := mlx.MulScalar(diff, scale)
|
||||
combPred := mlx.Add(negF32, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
|
||||
mlx.Eval(output)
|
||||
return mlx.ToBFloat16(output)
|
||||
}
|
||||
|
||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
||||
latents = m.VAE.Denormalize(latents)
|
||||
decoded := m.VAE.Decode(latents)
|
||||
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
mlx.Eval(decoded)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
||||
// Handles single or multiple input images with different resolutions.
|
||||
//
|
||||
// Parameters:
|
||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
||||
// - txtLen: text sequence length
|
||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
||||
//
|
||||
// Returns RoPE cache where:
|
||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
||||
// - Following positions are for each input image (interpolated from output res)
|
||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
|
||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame position
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = posFreqsT[framePos][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame -1: use last row of negative frame frequencies
|
||||
negFrameIdx := maxIdx - 1
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Total image sequence length: noise + all input images
|
||||
noiseSeqLen := outPH * outPW
|
||||
totalImgLen := noiseSeqLen
|
||||
for _, dims := range inputDims {
|
||||
totalImgLen += dims.PatchH * dims.PatchW
|
||||
}
|
||||
|
||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
||||
idx := int32(0)
|
||||
|
||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
||||
for y := int32(0); y < outPH; y++ {
|
||||
for x := int32(0); x < outPW; x++ {
|
||||
row := computePosFreqs(0, y, x)
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
|
||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
||||
// (_compute_condition_freqs), positive indices for others.
|
||||
numImages := len(inputDims)
|
||||
lastImgIdx := numImages - 1
|
||||
for imgIdx, dims := range inputDims {
|
||||
inPH := dims.PatchH
|
||||
inPW := dims.PatchW
|
||||
|
||||
// Determine frame index for this image
|
||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
||||
|
||||
// Map each input position to an output position using linear interpolation
|
||||
for y := int32(0); y < inPH; y++ {
|
||||
for x := int32(0); x < inPW; x++ {
|
||||
// Interpolate: map input (y, x) to output grid position
|
||||
// This is the key fix from DiffSynth's forward_sampling
|
||||
var yOut, xOut int32
|
||||
if inPH == 1 {
|
||||
yOut = 0
|
||||
} else {
|
||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
||||
yOut = y * (outPH - 1) / (inPH - 1)
|
||||
}
|
||||
if inPW == 1 {
|
||||
xOut = 0
|
||||
} else {
|
||||
xOut = x * (outPW - 1) / (inPW - 1)
|
||||
}
|
||||
|
||||
var row []float32
|
||||
if useNegFrame {
|
||||
// Last image in multi-image uses frame -1
|
||||
row = computeNegFrameFreqs(yOut, xOut)
|
||||
} else {
|
||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
||||
frameIdx := int32(imgIdx + 1)
|
||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
||||
}
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies - start after max video index
|
||||
maxVidIdx := max(outPH/2, outPW/2)
|
||||
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &qwen_image.RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
// Expected values from Python:
|
||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
||||
expectedFreqsT := []float64{
|
||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
||||
}
|
||||
|
||||
expectedFreqsH_first4 := []float64{
|
||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
||||
}
|
||||
|
||||
expectedFreqsH_last4 := []float64{
|
||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
||||
}
|
||||
|
||||
// Test temporal frequencies (dim=16)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
if len(freqsT) != 8 {
|
||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
||||
}
|
||||
for i, expected := range expectedFreqsT {
|
||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test height/width frequencies (dim=56)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
||||
if len(freqsH) != 28 {
|
||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
||||
}
|
||||
for i, expected := range expectedFreqsH_first4 {
|
||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
||||
}
|
||||
}
|
||||
for i, expected := range expectedFreqsH_last4 {
|
||||
idx := 24 + i // last 4 of 28
|
||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
||||
func TestMakeFreqTable(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Test positive table
|
||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
|
||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
||||
if posTable[0][i] != 1.0 {
|
||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
||||
}
|
||||
if posTable[0][i+1] != 0.0 {
|
||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
||||
}
|
||||
}
|
||||
|
||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
||||
}
|
||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
||||
}
|
||||
|
||||
// Test negative table
|
||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
||||
|
||||
// negTable[4095] corresponds to position -1
|
||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
||||
}
|
||||
|
||||
// negTable[4094] corresponds to position -2
|
||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
||||
cos2 := math.Cos(2.0)
|
||||
sin2 := math.Sin(2.0)
|
||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
||||
if !mlx.GPUIsAvailable() {
|
||||
t.Skip("GPU not available")
|
||||
}
|
||||
|
||||
mlx.SetDefaultDeviceCPU()
|
||||
|
||||
// 4x4 patch grid, single image
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
txtLen := int32(5)
|
||||
axesDims := []int32{16, 56, 56}
|
||||
|
||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
||||
|
||||
// Check shapes
|
||||
imgShape := cache.ImgFreqs.Shape()
|
||||
if imgShape[0] != 16 { // 4*4 patches
|
||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
||||
}
|
||||
|
||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
||||
mlx.Eval(imgFreqsCPU)
|
||||
imgData := imgFreqsCPU.Data()
|
||||
|
||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
||||
for i := 0; i < 16; i += 2 {
|
||||
cosVal := imgData[i]
|
||||
sinVal := imgData[i+1]
|
||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
||||
}
|
||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
||||
}
|
||||
}
|
||||
|
||||
cache.ImgFreqs.Free()
|
||||
cache.TxtFreqs.Free()
|
||||
}
|
||||
|
||||
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
|
||||
func TestScaleRopePositions(t *testing.T) {
|
||||
// For a 4x4 grid with scale_rope=True:
|
||||
// hHalf = 2, wHalf = 2
|
||||
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
//
|
||||
// Height positions:
|
||||
// y=0: -(4-2) + 0 = -2
|
||||
// y=1: -(4-2) + 1 = -1
|
||||
// y=2: 2 - 2 = 0
|
||||
// y=3: 3 - 2 = 1
|
||||
//
|
||||
// Same for width
|
||||
|
||||
pH, pW := int32(4), int32(4)
|
||||
hHalf := pH / 2
|
||||
wHalf := pW / 2
|
||||
hNegCount := pH - hHalf
|
||||
wNegCount := pW - wHalf
|
||||
|
||||
expectedH := []int32{-2, -1, 0, 1}
|
||||
expectedW := []int32{-2, -1, 0, 1}
|
||||
|
||||
for y := int32(0); y < pH; y++ {
|
||||
var hPos int32
|
||||
if y < hNegCount {
|
||||
hPos = -(pH - hHalf) + y
|
||||
} else {
|
||||
hPos = y - hNegCount
|
||||
}
|
||||
if hPos != expectedH[y] {
|
||||
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
|
||||
}
|
||||
}
|
||||
|
||||
for x := int32(0); x < pW; x++ {
|
||||
var wPos int32
|
||||
if x < wNegCount {
|
||||
wPos = -(pW - wHalf) + x
|
||||
} else {
|
||||
wPos = x - wNegCount
|
||||
}
|
||||
if wPos != expectedW[x] {
|
||||
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoPEHeadDimensions verifies the head dimension breakdown
|
||||
func TestRoPEHeadDimensions(t *testing.T) {
|
||||
// axes_dims_rope = [16, 56, 56]
|
||||
// Each dimension uses half the values for frequencies
|
||||
// So we get: 8 + 28 + 28 = 64 frequency values
|
||||
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
|
||||
|
||||
axesDims := []int32{16, 56, 56}
|
||||
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
|
||||
expectedHeadDim := expectedFreqs * 2
|
||||
|
||||
if expectedFreqs != 64 {
|
||||
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
|
||||
}
|
||||
if expectedHeadDim != 128 {
|
||||
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
|
||||
}
|
||||
|
||||
// This should match the transformer's attention head dimension
|
||||
// hidden_size = 3072, num_heads = 24
|
||||
// head_dim = 3072 / 24 = 128
|
||||
}
|
||||
|
||||
@@ -1,642 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// VAE is the full VAE with encoder and decoder
|
||||
type VAE struct {
|
||||
Config *VAEConfig
|
||||
Encoder *VAEEncoder
|
||||
Decoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the VAE from a directory
|
||||
func (m *VAE) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load weights as f32 for quality (matches Python default behavior)
|
||||
// VAE decoder precision is critical for final image quality
|
||||
fmt.Print(" Loading weights as f32... ")
|
||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
// Load encoder
|
||||
fmt.Print(" Loading encoder... ")
|
||||
m.Encoder = &VAEEncoder{}
|
||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load decoder
|
||||
fmt.Print(" Loading decoder... ")
|
||||
m.Decoder = &VAEDecoder{}
|
||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("decoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
||||
return m.Encoder.Encode(x)
|
||||
}
|
||||
|
||||
// Decode decodes latents to image
|
||||
// z: [B, C, T, H, W] latents (denormalized)
|
||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
||||
return m.Decoder.Decode(z)
|
||||
}
|
||||
|
||||
// Normalize applies latent normalization
|
||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
// Mean/std are f32, will match z dtype through broadcasting
|
||||
return mlx.Div(mlx.Sub(z, mean), std)
|
||||
}
|
||||
|
||||
// Denormalize reverses latent normalization
|
||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
// Convert latents to f32 for VAE decoder quality
|
||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// VAEEncoder is the encoder part of the VAE
|
||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
||||
// - Blocks 0,1: ResBlocks (base_dim)
|
||||
// - Block 2: Downsample
|
||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
||||
// - Block 5: Downsample + temporal
|
||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
||||
// - Block 8: Downsample + temporal
|
||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
||||
type VAEEncoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
ConvIn *CausalConv3d
|
||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
||||
MidBlock *MidBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
QuantConv *CausalConv3d
|
||||
}
|
||||
|
||||
// EncoderBlock is either a ResBlock or a Downsample
|
||||
type EncoderBlock interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
IsDownsample() bool
|
||||
}
|
||||
|
||||
// EncoderResBlock wraps ResBlock
|
||||
type EncoderResBlock struct {
|
||||
*ResBlock
|
||||
}
|
||||
|
||||
func (b *EncoderResBlock) IsDownsample() bool { return false }
|
||||
|
||||
// EncoderDownsample is a downsample layer
|
||||
type EncoderDownsample struct {
|
||||
Resample *CausalConv3d
|
||||
TimeConv *CausalConv3d // Optional temporal downsample
|
||||
}
|
||||
|
||||
func (d *EncoderDownsample) IsDownsample() bool { return true }
|
||||
|
||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Spatial downsample with stride 2
|
||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
||||
x = d.forwardSpatialDownsample(x)
|
||||
|
||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
||||
// Since we don't support streaming, we skip time_conv entirely.
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := d.Resample.Weight.Shape()
|
||||
outC := wShape[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
||||
|
||||
// Apply 2D conv with stride 2
|
||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
|
||||
if d.Resample.Bias != nil {
|
||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
||||
outH := (H + 1) / 2
|
||||
outW := (W + 1) / 2
|
||||
|
||||
// Reshape back to [B, T, H', W', C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// loadFromWeights loads the encoder from pre-loaded weights
|
||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
e.Config = cfg
|
||||
|
||||
// Conv in
|
||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvIn = convIn
|
||||
|
||||
// Encoder uses flat block structure:
|
||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
||||
|
||||
// Track dimensions
|
||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
||||
blockIdx := 0
|
||||
|
||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
||||
inDim := cfg.BaseDim
|
||||
if stage > 0 {
|
||||
inDim = dims[stage-1]
|
||||
}
|
||||
outDim := dims[stage]
|
||||
|
||||
// ResBlocks for this stage (num_res_blocks per stage)
|
||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
currentInDim := inDim
|
||||
if r > 0 {
|
||||
currentInDim = outDim
|
||||
}
|
||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, block)
|
||||
blockIdx++
|
||||
}
|
||||
|
||||
// Downsample after each stage except the last
|
||||
if stage < len(cfg.DimMult)-1 {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, down)
|
||||
blockIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.MidBlock = midBlock
|
||||
|
||||
// Norm out
|
||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.NormOut = normOut
|
||||
|
||||
// Conv out
|
||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvOut = convOut
|
||||
|
||||
// Quant conv
|
||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.QuantConv = quantConv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncoderResBlock{block}, nil
|
||||
}
|
||||
|
||||
// newEncoderDownsample creates a downsample layer for the encoder
|
||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var timeConv *CausalConv3d
|
||||
if temporal {
|
||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
||||
}
|
||||
|
||||
return &EncoderDownsample{
|
||||
Resample: resample,
|
||||
TimeConv: timeConv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(x)
|
||||
|
||||
// Conv in
|
||||
x = e.ConvIn.Forward(x)
|
||||
|
||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
||||
for _, block := range e.Blocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = e.MidBlock.Forward(x)
|
||||
|
||||
// Norm + silu
|
||||
{
|
||||
prev := x
|
||||
x = e.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Conv out
|
||||
{
|
||||
prev := x
|
||||
x = e.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Quant conv
|
||||
{
|
||||
prev := x
|
||||
x = e.QuantConv.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Get mode from distribution (first half of channels = mean)
|
||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
||||
shape := x.Shape()
|
||||
latentC := shape[4] / 2
|
||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
||||
|
||||
// Convert back to channels-first [N, C, T, H, W]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the decoder part of the VAE
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// loadFromWeights loads the decoder from pre-loaded weights
|
||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
d.Config = cfg
|
||||
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PostQuantConv = postQuantConv
|
||||
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvIn = convIn
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.MidBlock = midBlock
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.UpBlocks[i] = upBlock
|
||||
}
|
||||
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NormOut = normOut
|
||||
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvOut = convOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] denormalized latents
|
||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Convert from channels-first to channels-last
|
||||
{
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// PostQuantConv
|
||||
x = d.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// ConvIn
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = d.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks
|
||||
for _, upBlock := range d.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = d.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// ConvOut
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Post-processing: clamp and convert back to channels-first
|
||||
{
|
||||
prev := x
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// DownBlock handles downsampling in encoder
|
||||
type DownBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Downsampler *Downsample
|
||||
}
|
||||
|
||||
// newDownBlock creates a down block
|
||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var downsampler *Downsample
|
||||
if downsampleMode != "" {
|
||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
||||
}
|
||||
|
||||
return &DownBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Downsampler: downsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies down block
|
||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range d.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if d.Downsampler != nil {
|
||||
prev := x
|
||||
x = d.Downsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Downsample handles spatial downsampling
|
||||
type Downsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newDownsample creates a downsampler
|
||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Downsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := d.Conv.Shape()[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
||||
// For 3x3 stride 2: pad 1 on all sides
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
|
||||
// Conv with stride 2 using manual strided patching
|
||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
if d.Bias != nil {
|
||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
Shift float32 `json:"shift"` // 3.0
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
|
||||
}
|
||||
|
||||
// DefaultFlowMatchSchedulerConfig returns default config
|
||||
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
Shift: 3.0,
|
||||
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
|
||||
// This is used in Z-Image-Turbo for fast sampling
|
||||
type FlowMatchEulerScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Discretized timesteps
|
||||
Sigmas []float32 // Noise levels at each timestep
|
||||
NumSteps int // Number of inference steps
|
||||
}
|
||||
|
||||
// NewFlowMatchEulerScheduler creates a new scheduler
|
||||
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
|
||||
return &FlowMatchEulerScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
|
||||
s.SetTimestepsWithMu(numSteps, 0)
|
||||
}
|
||||
|
||||
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
|
||||
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
|
||||
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
|
||||
for i := 0; i <= numSteps; i++ {
|
||||
t := 1.0 - float32(i)/float32(numSteps)
|
||||
|
||||
// Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShifting && mu != 0 {
|
||||
t = s.timeShift(mu, t)
|
||||
}
|
||||
|
||||
s.Timesteps[i] = t
|
||||
s.Sigmas[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (match Python)
|
||||
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity/noise from the model
|
||||
// timestepIdx: current timestep index
|
||||
// sample: current noisy sample
|
||||
// Returns: denoised sample for next step
|
||||
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
// where v_t is the velocity predicted by the model
|
||||
dt := sigmaNext - sigma // This is negative (going from noise to clean)
|
||||
|
||||
// x_next = x + dt * velocity
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// ScaleSample scales the sample for model input (identity for flow matching)
|
||||
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Flow matching doesn't need scaling
|
||||
return sample
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps (implements Scheduler interface)
|
||||
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep
|
||||
// Used for img2img or inpainting
|
||||
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-t) * x_0 + t * noise
|
||||
t := s.Timesteps[timestepIdx]
|
||||
oneMinusT := 1.0 - t
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
|
||||
scaledNoise := mlx.MulScalar(noise, t)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return RandomNormal(shape, seed)
|
||||
}
|
||||
|
||||
// RandomNormal creates a random normal tensor using MLX
|
||||
func RandomNormal(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
|
||||
// Latent is 8x smaller than image (VAE downscale)
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
|
||||
return []int32{batchSize, latentChannels, latentH, latentW}
|
||||
}
|
||||
@@ -1,296 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen3 text encoder configuration
|
||||
type Qwen3Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// loadQwen3Config loads text encoder config from a JSON file
|
||||
func loadQwen3Config(path string) (*Qwen3Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OProj *nn.Linear `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking
|
||||
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj *nn.Linear `weight:"gate_proj"`
|
||||
UpProj *nn.Linear `weight:"up_proj"`
|
||||
DownProj *nn.Linear `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Qwen3Block represents a single Qwen3 transformer block
|
||||
type Qwen3Block struct {
|
||||
Attention *Qwen3Attention `weight:"self_attn"`
|
||||
MLP *Qwen3MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
|
||||
type Qwen3TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Qwen3Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from a directory
|
||||
func (m *Qwen3TextEncoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen3 text encoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = cfg
|
||||
|
||||
// Pre-allocate layers slice
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in Qwen3 chat format
|
||||
func ApplyChatTemplate(prompt string) string {
|
||||
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder
|
||||
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
embeddings := te.Forward(tokensArr)
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
@@ -1,692 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Z-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
Dim int32 `json:"dim"`
|
||||
NHeads int32 `json:"n_heads"`
|
||||
NKVHeads int32 `json:"n_kv_heads"`
|
||||
NLayers int32 `json:"n_layers"`
|
||||
NRefinerLayers int32 `json:"n_refiner_layers"`
|
||||
InChannels int32 `json:"in_channels"`
|
||||
PatchSize int32 `json:"-"` // Computed from AllPatchSize
|
||||
CapFeatDim int32 `json:"cap_feat_dim"`
|
||||
NormEps float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TScale float32 `json:"t_scale"`
|
||||
QKNorm bool `json:"qk_norm"`
|
||||
AxesDims []int32 `json:"axes_dims"`
|
||||
AxesLens []int32 `json:"axes_lens"`
|
||||
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings -> [B, 256]
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
// t: [B] timesteps
|
||||
|
||||
// Create sinusoidal embedding
|
||||
half := te.FreqEmbedSize / 2
|
||||
|
||||
// freqs = exp(-log(10000) * arange(half) / half)
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// t[:, None] * freqs[None, :] -> [B, half]
|
||||
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
|
||||
// embedding = [cos(args), sin(args)] -> [B, 256]
|
||||
cosArgs := mlx.Cos(args)
|
||||
sinArgs := mlx.Sin(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := te.Linear1.Forward(embedding)
|
||||
h = mlx.SiLU(h)
|
||||
h = te.Linear2.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// XEmbedder embeds image patches to model dimension
|
||||
type XEmbedder struct {
|
||||
Linear *nn.Linear `weight:"2-1"`
|
||||
}
|
||||
|
||||
// Forward embeds patchified image latents
|
||||
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, L, in_channels * 4] -> [B, L, dim]
|
||||
return xe.Linear.Forward(x)
|
||||
}
|
||||
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear *nn.Linear `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
|
||||
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
||||
// RMSNorm on last axis (uses 1e-6)
|
||||
h := ce.Norm.Forward(capFeats, 1e-6)
|
||||
// Linear projection
|
||||
return ce.Linear.Forward(h)
|
||||
}
|
||||
|
||||
// FeedForward implements SwiGLU FFN
|
||||
type FeedForward struct {
|
||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
||||
W2 *nn.Linear `weight:"w2"` // down projection
|
||||
W3 *nn.Linear `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Reshape for matmul
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
gate := ff.W1.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := ff.W3.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
out := ff.W2.Forward(h)
|
||||
|
||||
return mlx.Reshape(out, B, L, ff.OutDim)
|
||||
}
|
||||
|
||||
// Attention implements multi-head attention with QK norm
|
||||
type Attention struct {
|
||||
ToQ *nn.Linear `weight:"to_q"`
|
||||
ToK *nn.Linear `weight:"to_k"`
|
||||
ToV *nn.Linear `weight:"to_v"`
|
||||
ToOut *nn.Linear `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Dim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward computes attention
|
||||
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Project Q, K, V
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
q := attn.ToQ.Forward(xFlat)
|
||||
k := attn.ToK.Forward(xFlat)
|
||||
v := attn.ToV.Forward(xFlat)
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm
|
||||
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
|
||||
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
|
||||
|
||||
// Apply RoPE if provided
|
||||
if cos != nil && sin != nil {
|
||||
q = applyRoPE3D(q, cos, sin)
|
||||
k = applyRoPE3D(k, cos, sin)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
|
||||
|
||||
// Transpose back and reshape
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B*L, attn.Dim)
|
||||
out = attn.ToOut.Forward(out)
|
||||
|
||||
return mlx.Reshape(out, B, L, attn.Dim)
|
||||
}
|
||||
|
||||
// applyRoPE3D applies 3-axis rotary position embeddings
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// cos, sin: [B, L, 1, head_dim/2]
|
||||
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
half := headDim / 2
|
||||
|
||||
// Create even/odd index arrays
|
||||
evenIdx := make([]int32, half)
|
||||
oddIdx := make([]int32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
evenIdx[i] = i * 2
|
||||
oddIdx[i] = i*2 + 1
|
||||
}
|
||||
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
|
||||
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
|
||||
|
||||
// Extract x1 (even indices) and x2 (odd indices) along last axis
|
||||
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
|
||||
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
|
||||
|
||||
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
|
||||
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
|
||||
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
|
||||
|
||||
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
|
||||
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
|
||||
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
|
||||
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
|
||||
return mlx.Reshape(stacked, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// TransformerBlock is a single transformer block with optional AdaLN modulation
|
||||
type TransformerBlock struct {
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
|
||||
if tb.AdaLN != nil && adaln != nil {
|
||||
// Compute modulation: [B, 256] -> [B, 4*dim]
|
||||
chunks := tb.AdaLN.Forward(adaln)
|
||||
|
||||
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
|
||||
chunkShape := chunks.Shape()
|
||||
chunkDim := chunkShape[1] / 4
|
||||
|
||||
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
|
||||
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
|
||||
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
|
||||
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
|
||||
|
||||
// Expand for broadcasting: [B, 1, dim]
|
||||
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
|
||||
gateMSA = mlx.ExpandDims(gateMSA, 1)
|
||||
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
|
||||
gateMLP = mlx.ExpandDims(gateMLP, 1)
|
||||
|
||||
// Attention with modulation
|
||||
normX := tb.AttentionNorm1.Forward(x, eps)
|
||||
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
|
||||
attnOut := tb.Attention.Forward(normX, cos, sin)
|
||||
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
|
||||
|
||||
// FFN with modulation
|
||||
normFFN := tb.FFNNorm1.Forward(x, eps)
|
||||
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
|
||||
ffnOut := tb.FeedForward.Forward(normFFN)
|
||||
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
|
||||
} else {
|
||||
// No modulation (context refiner)
|
||||
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
|
||||
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
|
||||
|
||||
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
|
||||
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// FinalLayer outputs the denoised patches
|
||||
type FinalLayer struct {
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
// Forward computes final output
|
||||
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
|
||||
// c: [B, 256] -> scale: [B, dim]
|
||||
scale := mlx.SiLU(c)
|
||||
scale = fl.AdaLN.Forward(scale)
|
||||
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
|
||||
|
||||
// LayerNorm (affine=False) then scale
|
||||
x = layerNormNoAffine(x, 1e-6)
|
||||
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
||||
|
||||
// Output projection
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
x = fl.Output.Forward(x)
|
||||
|
||||
return mlx.Reshape(x, B, L, fl.OutDim)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Z-Image DiT model
|
||||
type Transformer struct {
|
||||
TEmbed *TimestepEmbedder `weight:"t_embedder"`
|
||||
XEmbed *XEmbedder `weight:"all_x_embedder"`
|
||||
CapEmbed *CapEmbedder `weight:"cap_embedder"`
|
||||
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
|
||||
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
|
||||
Layers []*TransformerBlock `weight:"layers"`
|
||||
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
|
||||
XPadToken *mlx.Array `weight:"x_pad_token"`
|
||||
CapPadToken *mlx.Array `weight:"cap_pad_token"`
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Z-Image transformer...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = cfg
|
||||
|
||||
// Pre-allocate slices for loader
|
||||
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.TEmbed.FreqEmbedSize = 256
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
||||
m.CapEmbed.Norm.Eps = 1e-6
|
||||
|
||||
for _, block := range m.NoiseRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.ContextRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.Layers {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTransformerConfig loads transformer config from a JSON file
|
||||
func loadTransformerConfig(path string) (*TransformerConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg TransformerConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
// Extract PatchSize from array
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
cfg.PatchSize = cfg.AllPatchSize[0]
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// initTransformerBlock sets computed fields on a transformer block
|
||||
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||
block.Dim = cfg.Dim
|
||||
block.HasModulation = block.AdaLN != nil
|
||||
|
||||
// Init attention computed fields
|
||||
attn := block.Attention
|
||||
attn.NHeads = cfg.NHeads
|
||||
attn.HeadDim = cfg.Dim / cfg.NHeads
|
||||
attn.Dim = cfg.Dim
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
||||
|
||||
// Init feedforward OutDim
|
||||
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
|
||||
|
||||
// Set eps on all RMSNorm layers
|
||||
block.AttentionNorm1.Eps = cfg.NormEps
|
||||
block.AttentionNorm2.Eps = cfg.NormEps
|
||||
block.FFNNorm1.Eps = cfg.NormEps
|
||||
block.FFNNorm2.Eps = cfg.NormEps
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE values
|
||||
type RoPECache struct {
|
||||
ImgCos *mlx.Array
|
||||
ImgSin *mlx.Array
|
||||
CapCos *mlx.Array
|
||||
CapSin *mlx.Array
|
||||
UnifiedCos *mlx.Array
|
||||
UnifiedSin *mlx.Array
|
||||
ImgLen int32
|
||||
CapLen int32
|
||||
}
|
||||
|
||||
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
||||
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
|
||||
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
|
||||
imgLen := hTok * wTok
|
||||
|
||||
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
|
||||
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
|
||||
imgPos = mlx.ToBFloat16(imgPos)
|
||||
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
|
||||
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
|
||||
capPos = mlx.ToBFloat16(capPos)
|
||||
|
||||
// Compute RoPE from UNIFIED positions
|
||||
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
|
||||
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
|
||||
|
||||
// Slice RoPE for image and caption parts
|
||||
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
|
||||
return &RoPECache{
|
||||
ImgCos: imgCos,
|
||||
ImgSin: imgSin,
|
||||
CapCos: capCos,
|
||||
CapSin: capSin,
|
||||
UnifiedCos: unifiedCos,
|
||||
UnifiedSin: unifiedSin,
|
||||
ImgLen: imgLen,
|
||||
CapLen: capLen,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs the Z-Image transformer with precomputed RoPE
|
||||
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Embed caption features -> [B, L_cap, dim]
|
||||
capEmb := m.CapEmbed.Forward(capFeats)
|
||||
|
||||
eps := m.NormEps
|
||||
|
||||
// Noise refiner: refine image patches with modulation
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Context refiner: refine caption (no modulation)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Main transformer layers use full unified RoPE
|
||||
for _, layer := range m.Layers {
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for faster inference.
|
||||
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
|
||||
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
|
||||
func (m *Transformer) ForwardWithCache(
|
||||
x *mlx.Array,
|
||||
t *mlx.Array,
|
||||
capFeats *mlx.Array,
|
||||
rope *RoPECache,
|
||||
stepCache *cache.StepCache,
|
||||
step int,
|
||||
cacheInterval int,
|
||||
) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
cacheLayers := stepCache.NumLayers()
|
||||
eps := m.NormEps
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Context refiners: compute once on step 0, reuse forever
|
||||
// (caption embedding doesn't depend on timestep or latents)
|
||||
var capEmb *mlx.Array
|
||||
if stepCache.GetConstant() != nil {
|
||||
capEmb = stepCache.GetConstant()
|
||||
} else {
|
||||
capEmb = m.CapEmbed.Forward(capFeats)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
stepCache.SetConstant(capEmb)
|
||||
}
|
||||
|
||||
// Noise refiners: always compute (depend on x which changes each step)
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Determine if this is a cache refresh step
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
// Main transformer layers with caching
|
||||
for i, layer := range m.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached output for shallow layers
|
||||
unified = stepCache.Get(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
// Cache shallow layer outputs on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, unified)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
|
||||
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
|
||||
// Create meshgrid and stack
|
||||
total := d0 * d1 * d2
|
||||
coords := make([]float32, total*3)
|
||||
|
||||
idx := 0
|
||||
for i := int32(0); i < d0; i++ {
|
||||
for j := int32(0); j < d1; j++ {
|
||||
for k := int32(0); k < d2; k++ {
|
||||
coords[idx*3+0] = float32(s0 + i)
|
||||
coords[idx*3+1] = float32(s1 + j)
|
||||
coords[idx*3+2] = float32(s2 + k)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArray(coords, []int32{1, total, 3})
|
||||
}
|
||||
|
||||
// prepareRoPE3D computes cos/sin for 3-axis RoPE
|
||||
// positions: [B, L, 3] with (h, w, t) coordinates
|
||||
// axesDims: [32, 48, 48] - dimensions for each axis
|
||||
// Returns: cos, sin each [B, L, 1, head_dim/2]
|
||||
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
|
||||
// Compute frequencies for each axis
|
||||
// dims = [32, 48, 48], so halves = [16, 24, 24]
|
||||
ropeTheta := float32(256.0)
|
||||
|
||||
freqs := make([]*mlx.Array, 3)
|
||||
for axis := 0; axis < 3; axis++ {
|
||||
half := axesDims[axis] / 2
|
||||
f := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
|
||||
}
|
||||
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
|
||||
}
|
||||
|
||||
// Extract position coordinates
|
||||
shape := positions.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// positions[:, :, 0] -> h positions
|
||||
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
|
||||
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
|
||||
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
|
||||
|
||||
// Compute args: pos * freqs for each axis
|
||||
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
|
||||
posW = mlx.ExpandDims(posW, 3)
|
||||
posT = mlx.ExpandDims(posT, 3)
|
||||
|
||||
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
|
||||
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
|
||||
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
|
||||
|
||||
// Concatenate: [B, L, 1, 16+24+24=64]
|
||||
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
|
||||
|
||||
// Compute cos and sin
|
||||
return mlx.Cos(args), mlx.Sin(args)
|
||||
}
|
||||
|
||||
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
|
||||
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize // H_tok
|
||||
pW := W / patchSize // W_tok
|
||||
|
||||
// Match Python exactly: reshape treating B=1 as part of contiguous data
|
||||
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
|
||||
|
||||
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
|
||||
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
|
||||
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
|
||||
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
|
||||
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
|
||||
|
||||
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
|
||||
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
|
||||
return mlx.Reshape(x, 1, C, H, W)
|
||||
}
|
||||
@@ -1,652 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"`
|
||||
OutChannels int32 `json:"out_channels"`
|
||||
LatentChannels int32 `json:"latent_channels"`
|
||||
BlockOutChannels []int32 `json:"block_out_channels"`
|
||||
LayersPerBlock int32 `json:"layers_per_block"`
|
||||
NormNumGroups int32 `json:"norm_num_groups"`
|
||||
ScalingFactor float32 `json:"scaling_factor"`
|
||||
ShiftFactor float32 `json:"shift_factor"`
|
||||
}
|
||||
|
||||
// loadVAEConfig loads VAE config from a JSON file
|
||||
func loadVAEConfig(path string) (*VAEConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg VAEConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// NewGroupNorm creates a group norm layer
|
||||
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
|
||||
return &GroupNormLayer{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
NumGroups: numGroups,
|
||||
Eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, C, H, W]
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Reshape to [B, groups, C/groups, H, W]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 2, true)
|
||||
mean = mlx.Mean(mean, 3, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
|
||||
variance = mlx.Mean(variance, 3, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, C, H, W]
|
||||
xNorm = mlx.Reshape(xNorm, B, C, H, W)
|
||||
|
||||
// Scale and shift (weight and bias are [C])
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer
|
||||
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
Bias *mlx.Array // [out_channels]
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// NewConv2D creates a Conv2D layer
|
||||
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
|
||||
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
|
||||
// Transpose weight from OIHW to OHWI
|
||||
// [O, I, H, W] -> [O, H, W, I]
|
||||
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
|
||||
return &Conv2D{
|
||||
Weight: weightOHWI,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies convolution
|
||||
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [N, C, H, W] -> [N, H, W, C]
|
||||
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
|
||||
|
||||
// Conv in NHWC format
|
||||
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
|
||||
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ResnetBlock2D implements a ResNet block for VAE
|
||||
type ResnetBlock2D struct {
|
||||
Norm1 *GroupNormLayer
|
||||
Conv1 *Conv2D
|
||||
Norm2 *GroupNormLayer
|
||||
Conv2 *Conv2D
|
||||
ConvShortcut *Conv2D // nil if in_channels == out_channels
|
||||
}
|
||||
|
||||
// NewResnetBlock2D creates a ResNet block
|
||||
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &ResnetBlock2D{
|
||||
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
|
||||
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
|
||||
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
|
||||
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
|
||||
}
|
||||
|
||||
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
|
||||
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
|
||||
}
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// Forward applies the ResNet block with staged evaluation
|
||||
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: norm1
|
||||
{
|
||||
h = rb.Norm1.Forward(x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: silu + conv1
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 3: norm2
|
||||
{
|
||||
prev := h
|
||||
h = rb.Norm2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: silu + conv2
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
{
|
||||
prev := h
|
||||
if rb.ConvShortcut != nil {
|
||||
shortcut := rb.ConvShortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// VAEAttentionBlock implements self-attention for VAE
|
||||
type VAEAttentionBlock struct {
|
||||
GroupNorm *GroupNormLayer
|
||||
ToQWeight *mlx.Array
|
||||
ToQBias *mlx.Array
|
||||
ToKWeight *mlx.Array
|
||||
ToKBias *mlx.Array
|
||||
ToVWeight *mlx.Array
|
||||
ToVBias *mlx.Array
|
||||
ToOutWeight *mlx.Array
|
||||
ToOutBias *mlx.Array
|
||||
NumHeads int32
|
||||
}
|
||||
|
||||
// NewVAEAttentionBlock creates an attention block
|
||||
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEAttentionBlock{
|
||||
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
|
||||
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
|
||||
ToQBias: toQBias,
|
||||
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
|
||||
ToKBias: toKBias,
|
||||
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
|
||||
ToVBias: toVBias,
|
||||
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
|
||||
ToOutBias: toOutBias,
|
||||
NumHeads: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies attention with staged evaluation
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: GroupNorm + reshape
|
||||
{
|
||||
h = ab.GroupNorm.Forward(x)
|
||||
h = mlx.Transpose(h, 0, 2, 3, 1)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
var out *mlx.Array
|
||||
|
||||
// Stage 2: Q, K, V projections + attention
|
||||
{
|
||||
q := mlx.Linear(h, ab.ToQWeight)
|
||||
q = mlx.Add(q, ab.ToQBias)
|
||||
k := mlx.Linear(h, ab.ToKWeight)
|
||||
k = mlx.Add(k, ab.ToKBias)
|
||||
v := mlx.Linear(h, ab.ToVWeight)
|
||||
v = mlx.Add(v, ab.ToVBias)
|
||||
h.Free()
|
||||
|
||||
q = mlx.ExpandDims(q, 1)
|
||||
k = mlx.ExpandDims(k, 1)
|
||||
v = mlx.ExpandDims(v, 1)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
out = mlx.Squeeze(out, 1)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
// Stage 3: Output projection + reshape + residual
|
||||
{
|
||||
prev := out
|
||||
out = mlx.Linear(out, ab.ToOutWeight)
|
||||
out = mlx.Add(out, ab.ToOutBias)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Transpose(out, 0, 3, 1, 2)
|
||||
out = mlx.Add(out, residual)
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UpDecoderBlock2D implements an upsampling decoder block
|
||||
type UpDecoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Upsample *Conv2D
|
||||
}
|
||||
|
||||
// NewUpDecoderBlock2D creates an up decoder block
|
||||
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var upsample *Conv2D
|
||||
if hasUpsample {
|
||||
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upsample = NewConv2D(upWeight, upBias, 1, 1)
|
||||
}
|
||||
|
||||
return &UpDecoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Upsample: upsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the up decoder block with staged evaluation to reduce peak memory
|
||||
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range ub.ResnetBlocks {
|
||||
prev := x
|
||||
x = resnet.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if ub.Upsample != nil {
|
||||
// Stage 1: Upsample2x (nearest neighbor)
|
||||
{
|
||||
prev := x
|
||||
x = Upsample2x(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Upsample conv
|
||||
{
|
||||
prev := x
|
||||
x = ub.Upsample.Forward(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block with attention
|
||||
type VAEMidBlock struct {
|
||||
Resnet1 *ResnetBlock2D
|
||||
Attention *VAEAttentionBlock
|
||||
Resnet2 *ResnetBlock2D
|
||||
}
|
||||
|
||||
// NewVAEMidBlock creates the mid block
|
||||
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEMidBlock{
|
||||
Resnet1: resnet1,
|
||||
Attention: attention,
|
||||
Resnet2: resnet2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the mid block with staged evaluation
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
// Attention handles its own pools
|
||||
prev = x
|
||||
x = mb.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
ConvIn *Conv2D
|
||||
MidBlock *VAEMidBlock
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading VAE decoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = cfg
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load conv_in
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load mid block
|
||||
fmt.Print(" Loading mid block... ")
|
||||
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load up blocks
|
||||
fmt.Print(" Loading up blocks... ")
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
hasUpsample := i < numBlocks-1
|
||||
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numBlocks)
|
||||
|
||||
// Load conv_norm_out
|
||||
fmt.Print(" Loading conv_norm_out... ")
|
||||
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load conv_out
|
||||
fmt.Print(" Loading conv_out... ")
|
||||
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
{
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
h = vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
h = vae.MidBlock.Forward(h)
|
||||
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
{
|
||||
prev := h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
|
||||
// x: [B, C, H, W] -> [B, C, H*2, W*2]
|
||||
func Upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
// Broadcast to [B, C, H, 2, W, 2]
|
||||
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
|
||||
// Reshape to [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H*2, W*2)
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -1,363 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// Layer caching options (speedup via shallow layer reuse)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 15)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen3TextEncoder
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Z-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Z-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.LayerCache {
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 15 // Half of 30 layers
|
||||
}
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.TransformerConfig
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
hTok := latentH / tcfg.PatchSize
|
||||
wTok := latentW / tcfg.PatchSize
|
||||
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
maxLen := posEmb.Shape()[1]
|
||||
if useCFG && negEmb.Shape()[1] > maxLen {
|
||||
maxLen = negEmb.Shape()[1]
|
||||
}
|
||||
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
|
||||
maxLen += pad
|
||||
}
|
||||
|
||||
posEmb = padToLength(posEmb, maxLen)
|
||||
if useCFG {
|
||||
negEmb = padToLength(negEmb, maxLen)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
|
||||
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
|
||||
|
||||
// Init latents [B, C, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
|
||||
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
|
||||
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
|
||||
mlx.Eval(ropeCache.UnifiedCos)
|
||||
}
|
||||
|
||||
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
|
||||
cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStartCapture(cfg.CapturePath)
|
||||
}
|
||||
|
||||
tCurr := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if stepCache != nil {
|
||||
// Use layer caching for faster inference
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
// Note: CFG with layer cache shares the cache between pos/neg
|
||||
// This is approximate but fast - neg prompt uses same cached shallow layers
|
||||
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
}
|
||||
} else {
|
||||
// Standard forward without caching
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
}
|
||||
}
|
||||
|
||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep latents and any cached arrays
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStopCapture()
|
||||
}
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
|
||||
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgCos.Free()
|
||||
ropeCache.ImgSin.Free()
|
||||
ropeCache.CapCos.Free()
|
||||
ropeCache.CapSin.Free()
|
||||
ropeCache.UnifiedCos.Free()
|
||||
ropeCache.UnifiedSin.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padToLength pads a sequence tensor to the target length by repeating the last token.
|
||||
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
|
||||
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
|
||||
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
|
||||
}
|
||||
|
||||
// CalculateShift computes the mu shift value for dynamic scheduling
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
baseSeqLen := float32(256)
|
||||
maxSeqLen := float32(4096)
|
||||
baseShift := float32(0.5)
|
||||
maxShift := float32(1.15)
|
||||
|
||||
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
|
||||
b := baseShift - m*baseSeqLen
|
||||
return float32(imgSeqLen)*m + b
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package nn provides neural network layer types.
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// Layer is the interface for neural network layers with a Forward method.
|
||||
type Layer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
||||
type Linear struct {
|
||||
Weight *mlx.Array `weight:"weight"` // [out_features, in_features]
|
||||
Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil
|
||||
}
|
||||
|
||||
// NewLinear creates a linear layer.
|
||||
// Weight should be [out_features, in_features].
|
||||
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
|
||||
// Quantizes the weight immediately and evaluates to break lazy dependencies.
|
||||
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
||||
// Eval immediately so bf16 weight can be freed
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies the linear transformation: x @ W.T + bias
|
||||
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
||||
w := mlx.Transpose(l.Weight, 1, 0)
|
||||
if l.Bias != nil {
|
||||
return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
|
||||
}
|
||||
return mlx.Linear(x, w)
|
||||
}
|
||||
|
||||
// ToQuantized converts this Linear to a QuantizedLinear.
|
||||
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: l.Bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// QuantizedLinear applies an affine transformation using quantized weights.
|
||||
// Equivalent to mlx.nn.QuantizedLinear.
|
||||
type QuantizedLinear struct {
|
||||
Weight *mlx.Array // Quantized weight data
|
||||
Scales *mlx.Array // Scale factors for dequantization
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias)
|
||||
Bias *mlx.Array // Layer bias [output_dims] or nil
|
||||
GroupSize int
|
||||
Bits int
|
||||
Mode string
|
||||
}
|
||||
|
||||
// Forward applies the quantized linear transformation.
|
||||
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
|
||||
if ql.Bias != nil {
|
||||
out = mlx.Add(out, ql.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm represents an RMS normalization layer.
|
||||
type RMSNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32 // optional: used if Forward called with eps=0
|
||||
}
|
||||
|
||||
// NewRMSNorm creates an RMSNorm layer (for models not using weight loader).
|
||||
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
|
||||
return &RMSNorm{Weight: weight, Eps: eps}
|
||||
}
|
||||
|
||||
// Forward applies RMS normalization. If eps=0, uses stored Eps.
|
||||
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
if eps == 0 {
|
||||
eps = rn.Eps
|
||||
}
|
||||
return mlx.RMSNorm(x, rn.Weight, eps)
|
||||
}
|
||||
|
||||
// Embedding represents an embedding layer.
|
||||
type Embedding struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// NewEmbedding creates an embedding layer.
|
||||
func NewEmbedding(weight *mlx.Array) *Embedding {
|
||||
return &Embedding{Weight: weight}
|
||||
}
|
||||
|
||||
// Forward looks up embeddings by indices.
|
||||
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
|
||||
return mlx.Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
||||
// RepeatKV repeats K/V tensors for grouped query attention
|
||||
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim]
|
||||
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
|
||||
if repeatFactor == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// Repeat along the new axis
|
||||
reps := []int32{1, 1, repeatFactor, 1, 1}
|
||||
x = mlx.Tile(x, reps)
|
||||
// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// ApplyCausalMask applies causal (lower triangular) mask to attention scores
|
||||
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
||||
// scores: [B, num_heads, S, S]
|
||||
shape := scores.Shape()
|
||||
seqLen := shape[2]
|
||||
|
||||
// Create causal mask: 1 for positions to keep, 0 for positions to mask
|
||||
mask := mlx.Tri(seqLen, seqLen, 0)
|
||||
|
||||
// Where mask is 0, set score to -inf
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
|
||||
// Broadcast mask to match scores shape
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]
|
||||
|
||||
// Use where: if mask > 0, keep scores, else -inf
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// ApplyCausalMaskWithOffset applies causal mask for cached attention
|
||||
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
|
||||
// offset: the starting position of the new queries (i.e., cache length)
|
||||
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
|
||||
if offset == 0 {
|
||||
return ApplyCausalMask(scores)
|
||||
}
|
||||
|
||||
shape := scores.Shape()
|
||||
queryLen := shape[2]
|
||||
keyLen := shape[3]
|
||||
|
||||
// For cached attention, new queries can attend to all cached keys plus
|
||||
// new keys up to and including their position.
|
||||
mask := mlx.Tri(queryLen, keyLen, int(offset))
|
||||
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]
|
||||
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// LayerNorm represents a standard layer normalization layer (with bias).
|
||||
type LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
|
||||
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
eps := ln.Eps
|
||||
if eps == 0 {
|
||||
eps = 1e-5
|
||||
}
|
||||
// Compute mean and variance along last dimension
|
||||
mean := mlx.Mean(x, -1, true)
|
||||
centered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
|
||||
normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))
|
||||
|
||||
// Scale and shift
|
||||
out := mlx.Mul(normalized, ln.Weight)
|
||||
if ln.Bias != nil {
|
||||
out = mlx.Add(out, ln.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,356 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package nn
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3, // row 0
|
||||
4, 5, 6, // row 1
|
||||
}, []int32{2, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Input: [1, 3]
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
|
||||
t.Errorf("expected [6, 15], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
|
||||
func TestLinearWithBias(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
}, []int32{2, 3})
|
||||
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
|
||||
mlx.Eval(weight, bias)
|
||||
|
||||
linear := NewLinear(weight, bias)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [6, 15] + [10, 20] = [16, 35]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
|
||||
t.Errorf("expected [16, 35], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearBatched verifies Linear works with batched input.
|
||||
func TestLinearBatched(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 0,
|
||||
0, 1,
|
||||
}, []int32{2, 2}) // Identity
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Batch of 3 inputs
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
}, []int32{3, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Identity should return same values
|
||||
data := out.Data()
|
||||
expected := []float32{1, 2, 3, 4, 5, 6}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNorm verifies RMSNorm computation.
|
||||
func TestRMSNorm(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
// Input with known RMS
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-1.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~1.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
|
||||
func TestRMSNormWithScale(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-2.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~2.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedding verifies embedding lookup.
|
||||
func TestEmbedding(t *testing.T) {
|
||||
// Embedding table: 4 tokens, dim 3
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
0, 0, 0, // token 0
|
||||
1, 1, 1, // token 1
|
||||
2, 2, 2, // token 2
|
||||
3, 3, 3, // token 3
|
||||
}, []int32{4, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Look up tokens [1, 3, 0]
|
||||
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
|
||||
mlx.Eval(indices)
|
||||
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKV verifies K/V repetition for GQA.
|
||||
func TestRepeatKV(t *testing.T) {
|
||||
// [B=1, num_kv_heads=2, S=2, head_dim=2]
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
// head 0
|
||||
1, 2, // pos 0
|
||||
3, 4, // pos 1
|
||||
// head 1
|
||||
5, 6, // pos 0
|
||||
7, 8, // pos 1
|
||||
}, []int32{1, 2, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
// Repeat factor 2: 2 kv heads -> 4 heads
|
||||
out := RepeatKV(x, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
shape := out.Shape()
|
||||
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
|
||||
t.Errorf("expected shape [1,4,2,2], got %v", shape)
|
||||
}
|
||||
|
||||
data := out.Data()
|
||||
// After repeat: head0, head0, head1, head1
|
||||
expected := []float32{
|
||||
1, 2, 3, 4, // head 0 (original)
|
||||
1, 2, 3, 4, // head 0 (repeat)
|
||||
5, 6, 7, 8, // head 1 (original)
|
||||
5, 6, 7, 8, // head 1 (repeat)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
|
||||
func TestRepeatKVNoOp(t *testing.T) {
|
||||
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := RepeatKV(x, 1)
|
||||
// Should return same pointer
|
||||
if out != x {
|
||||
t.Error("RepeatKV with factor 1 should return input unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMask verifies causal masking.
|
||||
func TestApplyCausalMask(t *testing.T) {
|
||||
// [B=1, heads=1, S=3, S=3] - all ones
|
||||
scores := mlx.Ones(1, 1, 3, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMask(scores)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Lower triangular should be 1, upper should be -1e9
|
||||
// Row 0: [1, -inf, -inf]
|
||||
// Row 1: [1, 1, -inf]
|
||||
// Row 2: [1, 1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:3])
|
||||
}
|
||||
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
|
||||
t.Errorf("row 1 wrong: %v", data[3:6])
|
||||
}
|
||||
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
|
||||
t.Errorf("row 2 wrong: %v", data[6:9])
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
|
||||
func TestApplyCausalMaskWithOffset(t *testing.T) {
|
||||
// Simulating: cache has 2 tokens, adding 1 new query
|
||||
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
|
||||
scores := mlx.Ones(1, 1, 1, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// With offset=2, query at position 2 can attend to all 3 positions
|
||||
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
|
||||
t.Errorf("expected [1, 1, 1], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
|
||||
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
|
||||
scores := mlx.Ones(1, 1, 2, 2)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 0)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Standard causal: [1, -inf], [1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:2])
|
||||
}
|
||||
if data[2] != 1 || data[3] != 1 {
|
||||
t.Errorf("row 1 wrong: %v", data[2:4])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearSmall benchmarks small Linear forward pass.
|
||||
func BenchmarkLinearSmall(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{256, 256}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 256}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
|
||||
func BenchmarkLinearLarge(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
|
||||
func BenchmarkRMSNorm(b *testing.B) {
|
||||
weight := mlx.Ones(4096)
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := norm.Forward(x, 0)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEmbedding benchmarks embedding lookup.
|
||||
func BenchmarkEmbedding(b *testing.B) {
|
||||
// Typical vocab size
|
||||
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Single token lookup
|
||||
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
|
||||
mlx.Eval(indices)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRepeatKV benchmarks K/V repetition.
|
||||
func BenchmarkRepeatKV(b *testing.B) {
|
||||
// Typical GQA setup: 8 kv heads -> 32 heads
|
||||
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := RepeatKV(x, 4)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// LoadModule loads weights into a struct using reflection and struct tags.
|
||||
//
|
||||
// Struct tags use the format: `weight:"path[,optional]"`
|
||||
// - path: the weight name suffix (appended to prefix)
|
||||
// - optional: if present, missing weights don't cause errors
|
||||
// - "-": skip this field entirely
|
||||
// - no tag on struct pointer: recurse with current prefix
|
||||
// - no tag on *mlx.Array: skip (computed fields don't need loading)
|
||||
//
|
||||
// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes.
|
||||
// The slice must be pre-allocated to the correct length.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Attention struct {
|
||||
// QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
// KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
// Cache *mlx.Array // no tag = skipped (computed field)
|
||||
// }
|
||||
//
|
||||
// err := LoadModule(&attn, weights, "model.layers.0")
|
||||
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind())
|
||||
}
|
||||
|
||||
var errs []string
|
||||
loadStruct(v, weights, prefix, &errs, false)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStruct recursively loads weights into a struct value.
|
||||
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldVal := v.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !fieldVal.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag
|
||||
tag, hasTag := field.Tag.Lookup("weight")
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag options
|
||||
optional := parentOptional
|
||||
weightPath := tag
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
weightPath = tag[:idx]
|
||||
if strings.Contains(tag[idx+1:], "optional") {
|
||||
optional = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build full path
|
||||
fullPath := joinPath(prefix, weightPath)
|
||||
|
||||
// For struct pointers without a tag, recurse with current prefix
|
||||
if !hasTag && fieldVal.Kind() == reflect.Ptr {
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) {
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
|
||||
// *mlx.Array - load directly (but skip if no tag - computed fields)
|
||||
if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) {
|
||||
if !hasTag {
|
||||
continue // no tag on *mlx.Array = computed field, skip
|
||||
}
|
||||
arr, err := weights.GetTensor(fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath)
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(arr))
|
||||
continue
|
||||
}
|
||||
|
||||
// Pointer to struct - allocate and recurse
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
if optional && !hasWeightsWithPrefix(weights, fullPath) {
|
||||
continue
|
||||
}
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
|
||||
loadSlice(fieldVal, weights, fullPath, errs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
|
||||
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
|
||||
for _, name := range weights.ListTensors() {
|
||||
if strings.HasPrefix(name, prefix+".") || name == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadSlice loads weights into each element of a slice of struct pointers.
|
||||
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
|
||||
elemStructType := v.Type().Elem().Elem()
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i)
|
||||
if elem.IsNil() {
|
||||
elem.Set(reflect.New(elemStructType))
|
||||
}
|
||||
loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false)
|
||||
}
|
||||
}
|
||||
|
||||
// joinPath joins path segments with dots, handling empty segments.
|
||||
func joinPath(prefix, suffix string) string {
|
||||
if prefix == "" {
|
||||
return suffix
|
||||
}
|
||||
if suffix == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SafetensorHeader represents the JSON header of a safetensors file
|
||||
type SafetensorHeader map[string]TensorInfo
|
||||
|
||||
// TensorInfo contains metadata about a tensor
|
||||
type TensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
DataOffsets [2]int `json:"data_offsets"`
|
||||
}
|
||||
|
||||
// parseSafetensorHeader reads only the JSON header from a safetensors file.
|
||||
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := f.Read(headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
var header SafetensorHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
delete(header, "__metadata__")
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// dtypeFromString converts safetensors dtype string to mlx.Dtype
|
||||
func dtypeFromString(s string) mlx.Dtype {
|
||||
switch strings.ToUpper(s) {
|
||||
case "F32", "FLOAT32":
|
||||
return mlx.DtypeFloat32
|
||||
case "F16", "FLOAT16":
|
||||
return mlx.DtypeFloat16
|
||||
case "BF16", "BFLOAT16":
|
||||
return mlx.DtypeBFloat16
|
||||
case "I32", "INT32":
|
||||
return mlx.DtypeInt32
|
||||
case "I64", "INT64":
|
||||
return mlx.DtypeInt64
|
||||
case "U8", "UINT8":
|
||||
return mlx.DtypeUint8
|
||||
default:
|
||||
return mlx.DtypeFloat32
|
||||
}
|
||||
}
|
||||
|
||||
// ModelWeights manages weights from multiple safetensor files.
|
||||
type ModelWeights struct {
|
||||
dir string // Model directory
|
||||
tensorFiles map[string]string // tensor name -> file path
|
||||
tensorInfo map[string]TensorInfo // tensor name -> metadata
|
||||
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
|
||||
cache map[string]*mlx.Array // tensor name -> array (after Load)
|
||||
}
|
||||
|
||||
// LoadModelWeights scans safetensor files and builds a tensor index.
|
||||
// This only reads JSON headers, not tensor data.
|
||||
func LoadModelWeights(dir string) (*ModelWeights, error) {
|
||||
mw := &ModelWeights{
|
||||
dir: dir,
|
||||
tensorFiles: make(map[string]string),
|
||||
tensorInfo: make(map[string]TensorInfo),
|
||||
nativeCache: make(map[string]*mlx.SafetensorsFile),
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
|
||||
header, err := parseSafetensorHeader(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
|
||||
}
|
||||
|
||||
for name, info := range header {
|
||||
mw.tensorFiles[name] = path
|
||||
mw.tensorInfo[name] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(mw.tensorFiles) == 0 {
|
||||
return nil, fmt.Errorf("no safetensor files found in %s", dir)
|
||||
}
|
||||
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// Load loads all tensors into cache with the specified dtype.
|
||||
// If dtype is 0, tensors are loaded in their original dtype.
|
||||
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
|
||||
// or native loading when tensors are already in the target dtype.
|
||||
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
|
||||
if dtype == 0 {
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// Check if any tensor needs conversion
|
||||
needsConversion := false
|
||||
for name := range mw.tensorFiles {
|
||||
info := mw.tensorInfo[name]
|
||||
if dtypeFromString(info.Dtype) != dtype {
|
||||
needsConversion = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsConversion {
|
||||
return mw.loadStreaming(dtype)
|
||||
}
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// loadNative loads all tensors using the native memory-mapped loader.
|
||||
func (mw *ModelWeights) loadNative() error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
arr := native.Get(name)
|
||||
if arr == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
mw.cache[name] = arr
|
||||
}
|
||||
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStreaming loads tensors with dtype conversion.
|
||||
// Uses the same pattern as Python: replace each entry in the map after conversion,
|
||||
// so the original tensor loses its reference and can be freed.
|
||||
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
src := native.Get(name)
|
||||
if src == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
|
||||
dst := mlx.AsType(src, dtype)
|
||||
mlx.Eval(dst)
|
||||
native.Set(name, dst)
|
||||
mw.cache[name] = dst
|
||||
}
|
||||
|
||||
native.Free()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a tensor from cache. Call Load() first.
|
||||
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
|
||||
if mw.cache == nil {
|
||||
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
||||
}
|
||||
arr, ok := mw.cache[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found in cache", name)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// GetTensor loads a tensor using the native loader without caching.
|
||||
// For bulk loading, use Load() + Get() instead.
|
||||
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
|
||||
if mw.cache != nil {
|
||||
if arr, ok := mw.cache[name]; ok {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
|
||||
path, ok := mw.tensorFiles[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found", name)
|
||||
}
|
||||
|
||||
native, ok := mw.nativeCache[path]
|
||||
if !ok {
|
||||
var err error
|
||||
native, err = mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return native.Get(name), nil
|
||||
}
|
||||
|
||||
// GetTensorInfo returns metadata about a tensor without loading it.
|
||||
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
|
||||
info, ok := mw.tensorInfo[name]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ListTensors returns all tensor names.
|
||||
func (mw *ModelWeights) ListTensors() []string {
|
||||
names := make([]string, 0, len(mw.tensorFiles))
|
||||
for name := range mw.tensorFiles {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// HasTensor checks if a tensor exists.
|
||||
func (mw *ModelWeights) HasTensor(name string) bool {
|
||||
_, ok := mw.tensorFiles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ReleaseAll releases all cached native file handles.
|
||||
func (mw *ModelWeights) ReleaseAll() {
|
||||
for path, native := range mw.nativeCache {
|
||||
native.Free()
|
||||
delete(mw.nativeCache, path)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestLoadModelWeights(t *testing.T) {
|
||||
// Skip if no model available
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Check we found tensors
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Fatal("no tensors found")
|
||||
}
|
||||
t.Logf("found %d tensors", len(tensors))
|
||||
|
||||
// Check HasTensor
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Errorf("HasTensor(%q) = false", tensors[0])
|
||||
}
|
||||
if mw.HasTensor("nonexistent.weight") {
|
||||
t.Error("HasTensor returned true for nonexistent tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Skip("no tensors")
|
||||
}
|
||||
|
||||
// Load first tensor
|
||||
arr, err := mw.GetTensor(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
|
||||
}
|
||||
|
||||
// Verify it has a shape
|
||||
shape := arr.Shape()
|
||||
if len(shape) == 0 {
|
||||
t.Error("tensor has no shape")
|
||||
}
|
||||
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLoadWithDtype(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Load all tensors as bfloat16
|
||||
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
// Get a tensor from cache
|
||||
tensors := mw.ListTensors()
|
||||
arr, err := mw.Get(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
|
||||
// Verify dtype (unless it was already bf16)
|
||||
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLookupTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// HasTensor returns false for nonexistent
|
||||
if mw.HasTensor("nonexistent") {
|
||||
t.Error("HasTensor should return false for nonexistent")
|
||||
}
|
||||
|
||||
// HasTensor returns true for existing tensor
|
||||
tensors := mw.ListTensors()
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Error("HasTensor should return true for existing tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorHeader(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
// Find a safetensors file
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var stFile string
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".safetensors" {
|
||||
stFile = filepath.Join(modelDir, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if stFile == "" {
|
||||
t.Skip("no safetensors file found")
|
||||
}
|
||||
|
||||
header, err := parseSafetensorHeader(stFile)
|
||||
if err != nil {
|
||||
t.Fatalf("parseSafetensorHeader: %v", err)
|
||||
}
|
||||
|
||||
if len(header) == 0 {
|
||||
t.Error("header is empty")
|
||||
}
|
||||
|
||||
// Check a tensor has valid info
|
||||
for name, info := range header {
|
||||
if info.Dtype == "" {
|
||||
t.Errorf("%s: empty dtype", name)
|
||||
}
|
||||
if len(info.Shape) == 0 {
|
||||
t.Errorf("%s: empty shape", name)
|
||||
}
|
||||
break // just check one
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
# Tokenizer
|
||||
|
||||
Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package.
|
||||
|
||||
## Features
|
||||
|
||||
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
|
||||
- **SentencePiece** - Gemma style with `▁` space handling
|
||||
- **WordPiece** - BERT style with `##` continuation tokens
|
||||
- **Parallel encoding** - Automatic parallelization for inputs >4KB
|
||||
- **HuggingFace compatible** - Loads `tokenizer.json` directly
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
// Load from HuggingFace model directory
|
||||
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Encode text to token IDs
|
||||
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
|
||||
|
||||
// Decode back to text
|
||||
text := tok.Decode(ids)
|
||||
|
||||
// Check special tokens
|
||||
if tok.IsEOS(ids[len(ids)-1]) {
|
||||
// End of sequence
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
Benchmarks on Apple M3 Max:
|
||||
|
||||
| Input Size | Encode | Decode | Tokens |
|
||||
|------------|--------|--------|--------|
|
||||
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
|
||||
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
|
||||
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
|
||||
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
|
||||
|
||||
Comparison with other implementations (10 MB input):
|
||||
|
||||
| Implementation | Encode Speed | Notes |
|
||||
|----------------|--------------|-------|
|
||||
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
|
||||
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
|
||||
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
|
||||
|
||||
## Performance Opportunities
|
||||
|
||||
Potential optimizations not yet implemented:
|
||||
|
||||
| Optimization | Expected Gain | Complexity |
|
||||
|--------------|---------------|------------|
|
||||
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
|
||||
| Custom regex engine (like tiktoken) | 1.5-2x | High |
|
||||
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
|
||||
| Assembly BPE merge loop | 1.2-1.5x | High |
|
||||
| Memoization for repeated substrings | Variable | Low |
|
||||
|
||||
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
|
||||
|
||||
## Not Yet Implemented
|
||||
|
||||
| Feature | Used By | Notes |
|
||||
|---------|---------|-------|
|
||||
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
|
||||
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
|
||||
| Custom pretokenizers | Model-specific | Beyond standard patterns |
|
||||
|
||||
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `tokenizer.go` | Main implementation (~1000 lines) |
|
||||
| `tokenizer_test.go` | Tests and benchmarks |
|
||||
| `testdata/` | Mini tokenizer for unit tests |
|
||||
@@ -1 +0,0 @@
|
||||
{"model": {"type": "BPE", "vocab": {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "fé": 59958, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "r": 81, "q": 80, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "¡": 94, "¢": 95, "£": 96, "¤": 97, "¥": 98, "¦": 99, "§": 100, "¨": 101, "World": 10343, "©": 102, "ª": 103, "«": 104, "¬": 105, "®": 106, "world": 14957, "¯": 107, "°": 108, "±": 109, "²": 110, "³": 111, "´": 112, "µ": 113, "¶": 114, "·": 115, "¸": 116, "¹": 117, "º": 118, "»": 119, "¼": 120, "½": 121, "¾": 122, "¿": 123, "À": 124, "Á": 125, "Â": 126, "Ã": 127, "Ä": 128, "Å": 129, "Æ": 130, "Ç": 131, "È": 132, "É": 133, "Ê": 134, "Ë": 135, "Ì": 136, "Í": 137, "Î": 138, "Ï": 139, "Ð": 140, "Ñ": 141, "Ò": 142, "Ó": 143, "Ô": 144, "Õ": 145, "Ö": 146, "×": 147, "Ø": 148, "Ù": 149, "Ú": 150, "Û": 151, "Ü": 152, "Ý": 153, "Þ": 154, "ß": 155, "à": 156, "á": 157, "â": 158, "ã": 159, "ä": 160, "å": 161, "æ": 162, "ç": 163, "è": 164, "é": 165, "ê": 166, "ë": 167, "ì": 168, "Ġhello": 24748, "í": 169, "î": 170, "ï": 171, "ð": 172, "ñ": 173, "Hello": 9906, "ò": 174, "ó": 175, "ô": 176, "õ": 177, "ö": 178, "Ġ{}": 4792, "÷": 179, "ø": 180, "ù": 181, "ú": 182, "û": 183, "ü": 184, "ý": 185, "þ": 186, "ÿ": 187, "Ā": 188, "ā": 189, "Ă": 190, "ă": 191, "Ċ": 198, "Ą": 192, "ą": 193, "Ć": 194, "ć": 195, "Ĉ": 196, "ĉ": 197, "ċ": 199, "Č": 200, "č": 201, "Ď": 202, "ď": 203, "Đ": 204, "đ": 205, "Ē": 206, "ē": 207, "Ĕ": 208, "ĕ": 209, "Ė": 210, "ė": 211, "Ę": 212, "ę": 213, "Ġ": 220, "Ě": 214, "ě": 215, "Ĝ": 216, "ĝ": 217, "Ğ": 218, "ğ": 219, "ġ": 221, "Ģ": 222, "ģ": 223, "Ĥ": 224, "ĥ": 225, "Ħ": 226, "ħ": 227, "Ĩ": 228, "ĩ": 229, "Ī": 230, "ī": 231, "Ĭ": 232, "ĭ": 233, "Į": 234, "į": 235, "İ": 236, "ı": 237, "IJ": 238, "ij": 239, "Ĵ": 240, "ĵ": 241, "Ķ": 242, "ķ": 243, "ĸ": 244, "Ĺ": 245, "ĺ": 246, "Ļ": 247, "ļ": 248, "Ľ": 249, "ĠĠ": 256, "ľ": 250, "Ŀ": 251, "ŀ": 252, "Ł": 253, "rer": 38149, "ĠĠĠ": 262, "ł": 254, "Ń": 255, "'m": 2846, "'re": 2351, "can": 4919, "func": 2900, "()": 368, "Ġworld": 1917, "Ġmain": 1925, "00": 410, "123": 4513, "000": 931, "ca": 936, "'t": 956, "é": 978, "hello": 15339, "Ġw": 289, "orld": 1410, "Ġwor": 4191, "ld": 509, "main": 3902, "Ġm": 296, "ain": 467, "Ġma": 7643, "in": 258, "Ġmai": 17154, "re": 265, "'r": 97670, "unc": 1371, "fun": 12158, "fu": 33721, "nc": 1031, "ma": 1764, "mai": 77585, "wor": 50810, "or": 269, "Ġwo": 24670, "23": 1419, "12": 717, "{}": 6390, "Ġ{": 314, "an": 276, "ello": 4896, "Hel": 33813, "lo": 385, "Hell": 81394, "un": 359, "hel": 50222, "hell": 57195, "ai": 2192, "wo": 1146, "Ġh": 305, "Ġhel": 11591, "Ġhell": 15123, "el": 301, "He": 1548, "er": 261, "he": 383, "ell": 616, "ll": 657}, "merges": ["Ġ Ġ", "Ġ ĠĠ", "ĠĠ Ġ", "( )", "0 0", "0 00", "00 0", "c a", "' t", "à ©", "Ġ world", "Ġw orld", "Ġwor ld", "Ġ main", "Ġm ain", "Ġma in", "Ġmai n", "' re", "'r e", "' m", "f unc", "fun c", "fu nc", "m ain", "ma in", "mai n", "Ġ wor", "Ġw or", "Ġwo r", "1 23", "12 3", "Ġ {}", "Ġ{ }", "c an", "ca n", "{ }", "Ġ ma", "Ġm a", "H ello", "Hel lo", "Hell o", "W orld", "f un", "fu n", "w orld", "wor ld", "h ello", "hel lo", "hell o", "Ġ mai", "Ġm ai", "Ġma i", "Ġ wo", "Ġw o", "Ġ hello", "Ġh ello", "Ġhel lo", "Ġhell o", "f u", "H el", "He l", "r er", "re r", "h el", "he l", "w or", "wo r", "h ell", "he ll", "hel l", "f é", "m ai", "ma i", "H ell", "He ll", "Hel l", "' r"]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, "behavior": "Isolated", "invert": false}, {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": false}]}, "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true}, "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>", "special": true}, {"id": 128001, "content": "<|end_of_text|>", "special": true}]}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user