mirror of
https://github.com/ollama/ollama.git
synced 2026-01-09 08:01:40 -05:00
Compare commits
7 Commits
improve-cl
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a23b559b4c | ||
|
|
33ee7168ba | ||
|
|
34d0c55ea5 | ||
|
|
53a5a9e9ae | ||
|
|
e30e08a7d6 | ||
|
|
12e2b3514a | ||
|
|
626af2d809 |
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
@@ -68,6 +68,7 @@ jobs:
|
||||
name: bundles-darwin
|
||||
path: |
|
||||
dist/*.tgz
|
||||
dist/*.tar.zst
|
||||
dist/*.zip
|
||||
dist/*.dmg
|
||||
|
||||
@@ -392,13 +393,13 @@ jobs:
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | zstd --ultra -22 -T0 >$(basename ${ARCHIVE//.*/}.tar.zst);
|
||||
done
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bundles-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
|
||||
path: |
|
||||
*.tgz
|
||||
*.tar.zst
|
||||
|
||||
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
|
||||
docker-build-push:
|
||||
@@ -531,7 +532,7 @@ jobs:
|
||||
- name: Upload release artifacts
|
||||
run: |
|
||||
pids=()
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.exe dist/*.dmg ; do
|
||||
for payload in dist/*.txt dist/*.zip dist/*.tgz dist/*.tar.zst dist/*.exe dist/*.dmg ; do
|
||||
echo "Uploading $payload"
|
||||
gh release upload ${GITHUB_REF_NAME} $payload --clobber &
|
||||
pids[$!]=$!
|
||||
|
||||
@@ -2,6 +2,22 @@ cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(Ollama C CXX)
|
||||
|
||||
# Handle cross-compilation on macOS: when CMAKE_OSX_ARCHITECTURES is set to a
|
||||
# single architecture different from the host, override CMAKE_SYSTEM_PROCESSOR
|
||||
# to match. This is necessary because CMAKE_SYSTEM_PROCESSOR defaults to the
|
||||
# host architecture, but downstream projects (like MLX) use it to detect the
|
||||
# target architecture.
|
||||
if(CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES ";")
|
||||
# Single architecture specified
|
||||
if(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
|
||||
message(STATUS "Cross-compiling for x86_64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to x86_64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "x86_64")
|
||||
elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
message(STATUS "Cross-compiling for arm64: overriding CMAKE_SYSTEM_PROCESSOR from ${CMAKE_SYSTEM_PROCESSOR} to arm64")
|
||||
set(CMAKE_SYSTEM_PROCESSOR "arm64")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
@@ -12,7 +28,7 @@ set(BUILD_SHARED_LIBS ON)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
set(CMAKE_CXX_EXTENSIONS ON) # Recent versions of MLX Requires gnu++17 extensions to compile properly
|
||||
|
||||
set(GGML_BUILD ON)
|
||||
set(GGML_SHARED ON)
|
||||
@@ -147,14 +163,48 @@ if(CMAKE_HIP_COMPILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
if(NOT APPLE)
|
||||
find_package(Vulkan)
|
||||
if(Vulkan_FOUND)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
|
||||
install(TARGETS ggml-vulkan
|
||||
RUNTIME_DEPENDENCIES
|
||||
PRE_INCLUDE_REGEXES vulkan
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(MLX_ENGINE "Enable MLX backend" OFF)
|
||||
|
||||
if(MLX_ENGINE)
|
||||
message(STATUS "Setting up MLX (this takes a while...)")
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/x/ml/backend/mlx)
|
||||
|
||||
# Find CUDA toolkit if MLX is built with CUDA support
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
install(TARGETS mlx mlxc
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
|
||||
)
|
||||
|
||||
# Manually install cudart and cublas since they might not be picked up as direct dependencies
|
||||
if(CUDAToolkit_FOUND)
|
||||
file(GLOB CUDART_LIBS
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcudart.so*"
|
||||
"${CUDAToolkit_LIBRARY_DIR}/libcublas.so*")
|
||||
if(CUDART_LIBS)
|
||||
install(FILES ${CUDART_LIBS}
|
||||
DESTINATION ${OLLAMA_INSTALL_DIR}
|
||||
COMPONENT MLX)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
@@ -41,7 +41,7 @@
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
|
||||
"CMAKE_CUDA_FLAGS": "-t 2",
|
||||
"CMAKE_CUDA_FLAGS": "-t 4",
|
||||
"OLLAMA_RUNNER_DIR": "cuda_v13"
|
||||
}
|
||||
},
|
||||
@@ -83,6 +83,28 @@
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "vulkan"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"inherits": [ "Default" ],
|
||||
"cacheVariables": {
|
||||
"MLX_ENGINE": "ON",
|
||||
"OLLAMA_RUNNER_DIR": "mlx"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"inherits": [ "MLX", "CUDA 12" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v12"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"inherits": [ "MLX", "CUDA 13" ],
|
||||
"cacheVariables": {
|
||||
"OLLAMA_RUNNER_DIR": "mlx_cuda_v13"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
@@ -140,6 +162,21 @@
|
||||
"name": "Vulkan",
|
||||
"targets": [ "ggml-vulkan" ],
|
||||
"configurePreset": "Vulkan"
|
||||
},
|
||||
{
|
||||
"name": "MLX",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 12",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 12"
|
||||
},
|
||||
{
|
||||
"name": "MLX CUDA 13",
|
||||
"targets": [ "mlx", "mlxc" ],
|
||||
"configurePreset": "MLX CUDA 13"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
36
Dockerfile
36
Dockerfile
@@ -131,7 +131,39 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'Vulkan' \
|
||||
&& cmake --build --parallel --preset 'Vulkan' \
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
&& cmake --install build --component Vulkan --strip --parallel 8
|
||||
|
||||
FROM base AS mlx
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} \
|
||||
&& dnf install -y openblas-devel lapack-devel \
|
||||
&& dnf install -y libcudnn9-cuda-13 libcudnn9-devel-cuda-13 \
|
||||
&& dnf install -y libnccl libnccl-devel
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ENV BLAS_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV LAPACK_INCLUDE_DIRS=/usr/include/openblas
|
||||
ENV CGO_LDFLAGS="-L/usr/local/cuda-13/lib64 -L/usr/local/cuda-13/targets/x86_64-linux/lib/stubs"
|
||||
ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
# TODO wire up the actual MLX engine here instead of building the main binary...
|
||||
RUN mkdir -p dist/bin
|
||||
RUN go build -tags mlx -trimpath -buildmode=pie -o dist/bin/imagegen ./x/imagegen/cmd/engine
|
||||
|
||||
|
||||
FROM base AS build
|
||||
@@ -153,6 +185,8 @@ FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -18,8 +21,13 @@ type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
// TODO is this needed?
|
||||
ModelType string `json:"model_type"`
|
||||
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ModelType string `json:"model_type"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
@@ -33,8 +41,94 @@ type AdapterParameters struct {
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
kv := ggml.KV{
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
return kv.String("general.architecture", "unknown")
|
||||
}
|
||||
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
[]uint8 | []int8 | []uint16 | []int16 |
|
||||
[]uint32 | []int32 | []uint64 | []int64 |
|
||||
[]string | []float32 | []float64 | []bool
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []string{""})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []int32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []uint32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []float32{0})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, []bool{false})...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (ModelParameters) KV(t *Tokenizer) KV {
|
||||
kv := KV{
|
||||
"general.file_type": uint32(1),
|
||||
"general.quantization_version": uint32(2),
|
||||
"tokenizer.ggml.pre": t.Pre,
|
||||
@@ -63,7 +157,7 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p AdapterParameters) KV() ggml.KV {
|
||||
func (p AdapterParameters) KV() KV {
|
||||
var alpha float32
|
||||
if p.LoraParameters.Alpha == 0 {
|
||||
alpha = float32(p.Alpha)
|
||||
@@ -71,7 +165,7 @@ func (p AdapterParameters) KV() ggml.KV {
|
||||
alpha = p.LoraParameters.Alpha
|
||||
}
|
||||
|
||||
kv := ggml.KV{
|
||||
kv := KV{
|
||||
"adapter.lora.alpha": alpha,
|
||||
"adapter.type": "lora",
|
||||
"general.file_type": uint32(1),
|
||||
@@ -88,9 +182,14 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
type ModelKV interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
KV(*Tokenizer) KV
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
ModelKV
|
||||
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -107,7 +206,7 @@ type moreParser interface {
|
||||
|
||||
type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ggml.KV) ggml.KV
|
||||
KV(ofs.Config) KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
@@ -115,7 +214,7 @@ type AdapterConverter interface {
|
||||
Replacements() []string
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ofs.Config) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -126,8 +225,8 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
return err
|
||||
}
|
||||
|
||||
arch, ok := baseKV["general.architecture"]
|
||||
if !ok {
|
||||
arch := baseKV.Architecture()
|
||||
if arch == "" {
|
||||
return errors.New("architecture not set for the base model")
|
||||
}
|
||||
|
||||
@@ -153,23 +252,19 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var p ModelParameters
|
||||
if err := json.Unmarshal(bts, &p); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(p.Architectures) < 1 {
|
||||
return errors.New("unknown architecture")
|
||||
return nil, nil, errors.New("unknown architecture")
|
||||
}
|
||||
|
||||
var conv ModelConverter
|
||||
@@ -217,22 +312,22 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
case "DeepseekV3ForCausalLM":
|
||||
conv = &deepseek2Model{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
@@ -254,6 +349,19 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
return conv, t, nil
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
kv, t, err := LoadModelMetadata(fsys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conv := kv.(ModelConverter)
|
||||
|
||||
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
||||
if err != nil {
|
||||
@@ -263,7 +371,7 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
func writeFile(f *os.File, kv KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
|
||||
@@ -88,7 +88,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *bertModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
|
||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
||||
|
||||
var _ ModelConverter = (*commandrModel)(nil)
|
||||
|
||||
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *commandrModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "command-r"
|
||||
kv["general.name"] = "command-r"
|
||||
|
||||
@@ -47,7 +47,7 @@ type deepseek2Model struct {
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *deepseek2Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseek2"
|
||||
kv["general.type"] = "model"
|
||||
|
||||
@@ -41,7 +41,7 @@ type deepseekocr struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *deepseekocr) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "deepseekocr"
|
||||
kv["block_count"] = m.LanguageConfig.HiddenLayers
|
||||
|
||||
@@ -23,7 +23,7 @@ type gemmaModel struct {
|
||||
|
||||
var _ ModelConverter = (*gemmaModel)(nil)
|
||||
|
||||
func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemmaModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma"
|
||||
kv["gemma.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package convert
|
||||
|
||||
import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma2Model struct {
|
||||
gemmaModel
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
@@ -9,7 +7,7 @@ type gemma2Model struct {
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
}
|
||||
|
||||
func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemma2Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma2"
|
||||
kv["gemma2.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,7 @@ type gemma2Adapter struct {
|
||||
|
||||
var _ AdapterConverter = (*gemma2Adapter)(nil)
|
||||
|
||||
func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
func (p *gemma2Adapter) KV(baseKV fs.Config) KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "gemma2"
|
||||
return kv
|
||||
|
||||
@@ -3,8 +3,6 @@ package convert
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type gemma3Model struct {
|
||||
@@ -55,7 +53,7 @@ const (
|
||||
gemma27BLayerCount = 62
|
||||
)
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *gemma3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ type gemma3nModel struct {
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
|
||||
@@ -37,7 +37,7 @@ type gptossModel struct {
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *gptossModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
|
||||
@@ -48,7 +48,7 @@ type llamaModel struct {
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
|
||||
func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *llamaModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -35,7 +35,7 @@ type llama4Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *llama4Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -18,13 +19,13 @@ type llamaAdapter struct {
|
||||
|
||||
var _ AdapterConverter = (*llamaAdapter)(nil)
|
||||
|
||||
func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
func (p *llamaAdapter) KV(baseKV fs.Config) KV {
|
||||
kv := p.AdapterParameters.KV()
|
||||
kv["general.architecture"] = "llama"
|
||||
kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
|
||||
kv["llama.attention.head_count_kv"] = baseKV["llama.attention.head_count_kv"]
|
||||
kv["llama.attention.head_count"] = baseKV.Value("llama.attention.head_count")
|
||||
kv["llama.attention.head_count_kv"] = baseKV.Value("llama.attention.head_count_kv")
|
||||
|
||||
p.NumAttentionHeads = baseKV["llama.attention.head_count"].(uint32)
|
||||
p.NumAttentionHeads = baseKV.Value("llama.attention.head_count").(uint32)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ type mistral3Model struct {
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mistral3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
@@ -39,7 +39,7 @@ type mistral3CausalModel struct {
|
||||
} `json:"rope_parameters"`
|
||||
}
|
||||
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mistral3CausalModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.VocabSize
|
||||
|
||||
@@ -12,7 +12,7 @@ type mixtralModel struct {
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
}
|
||||
|
||||
func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *mixtralModel) KV(t *Tokenizer) KV {
|
||||
kv := p.llamaModel.KV(t)
|
||||
|
||||
if p.NumLocalExperts > 0 {
|
||||
|
||||
@@ -34,7 +34,7 @@ type mllamaModel struct {
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *mllamaModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *nomicbertModel) parseMore(fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *nomicbertModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
|
||||
// Determine architecture based on MoE parameters (following qwen3 pattern)
|
||||
|
||||
@@ -34,7 +34,7 @@ type olmoModel struct {
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *olmoModel) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo3"
|
||||
kv["olmo3.block_count"] = p.NumHiddenLayers
|
||||
|
||||
@@ -37,7 +37,7 @@ type phi3Model struct {
|
||||
|
||||
var _ ModelConverter = (*phi3Model)(nil)
|
||||
|
||||
func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (p *phi3Model) KV(t *Tokenizer) KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "phi3"
|
||||
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||
|
||||
@@ -22,7 +22,7 @@ type qwen2Model struct {
|
||||
|
||||
var _ ModelConverter = (*qwen2Model)(nil)
|
||||
|
||||
func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen2Model) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen2"
|
||||
kv["qwen2.block_count"] = q.HiddenLayers
|
||||
|
||||
@@ -29,7 +29,7 @@ type qwen25VLModel struct {
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type qwen3Model struct {
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (q *qwen3Model) KV(t *Tokenizer) ggml.KV {
|
||||
func (q *qwen3Model) KV(t *Tokenizer) KV {
|
||||
arch := "qwen3"
|
||||
if q.NumExperts > 0 {
|
||||
arch += "moe"
|
||||
|
||||
@@ -45,7 +45,7 @@ func (m *qwen3VLModel) parseMore(fsys fs.FS) error {
|
||||
return json.Unmarshal(bts, &m.VisionModel)
|
||||
}
|
||||
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
func (m *qwen3VLModel) KV(t *Tokenizer) KV {
|
||||
kv := m.qwen3Model.KV(t)
|
||||
|
||||
arch := "qwen3vl"
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fsc "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ type tensorData struct {
|
||||
Shape []int `json:"shape"`
|
||||
}
|
||||
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
func convertFull(t *testing.T, fsys fs.FS) (*os.File, fsc.Config, ggml.Tensors) {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||
@@ -59,9 +60,10 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
return r, m.KV(), m.Tensors()
|
||||
}
|
||||
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
|
||||
func generateResultsJSON(t *testing.T, f *os.File, kv fsc.Config, tensors ggml.Tensors) map[string]string {
|
||||
actual := make(map[string]string)
|
||||
for k, v := range kv {
|
||||
for k := range kv.Keys() {
|
||||
v := kv.Value(k)
|
||||
if s, ok := v.(json.Marshaler); !ok {
|
||||
actual[k] = fmt.Sprintf("%v", v)
|
||||
} else {
|
||||
@@ -277,7 +279,7 @@ func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[str
|
||||
func TestConvertAdapter(t *testing.T) {
|
||||
type AdapterCase struct {
|
||||
Name string
|
||||
BaseKV map[string]any
|
||||
BaseKV KV
|
||||
Expected map[string]string
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -41,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -50,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -146,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package fs
|
||||
|
||||
import "iter"
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
@@ -11,4 +13,8 @@ type Config interface {
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
|
||||
Len() int
|
||||
Keys() iter.Seq[string]
|
||||
Value(key string) any
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -239,6 +241,18 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Len() int {
|
||||
return len(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Keys() iter.Seq[string] {
|
||||
return maps.Keys(kv)
|
||||
}
|
||||
|
||||
func (kv KV) Value(key string) any {
|
||||
return kv[key]
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"bert",
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -508,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
func WriteGGUF(f *os.File, kv fs.Config, ts []*Tensor) error {
|
||||
arch := kv.String("general.architecture")
|
||||
if arch == "" {
|
||||
return fmt.Errorf("architecture not set")
|
||||
@@ -526,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(kv.Len())); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
|
||||
for _, key := range slices.Sorted(kv.Keys()) {
|
||||
if err := ggufWriteKV(f, arch, key, kv.Value(key)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -801,7 +802,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -6,9 +6,6 @@ import (
|
||||
|
||||
var ErrInterrupt = errors.New("Interrupt")
|
||||
|
||||
// ErrExpandOutput is returned when user presses Ctrl+O to expand tool output
|
||||
var ErrExpandOutput = errors.New("ExpandOutput")
|
||||
|
||||
type InterruptError struct {
|
||||
Line []rune
|
||||
}
|
||||
|
||||
@@ -206,9 +206,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.DeleteBefore()
|
||||
case CharCtrlL:
|
||||
buf.ClearScreen()
|
||||
case CharCtrlO:
|
||||
// Ctrl+O - expand tool output
|
||||
return "", ErrExpandOutput
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
|
||||
@@ -42,18 +42,39 @@ shift $(( $OPTIND - 1 ))
|
||||
_build_darwin() {
|
||||
for ARCH in $ARCHS; do
|
||||
status "Building darwin $ARCH"
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
INSTALL_PREFIX=dist/darwin-$ARCH/
|
||||
|
||||
if [ "$ARCH" = "amd64" ]; then
|
||||
status "Building darwin $ARCH dynamic backends"
|
||||
cmake -B build/darwin-$ARCH \
|
||||
BUILD_DIR=build/darwin-$ARCH
|
||||
cmake -B $BUILD_DIR \
|
||||
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \
|
||||
-DMLX_ENGINE=ON \
|
||||
-DMLX_ENABLE_X64_MAC=ON \
|
||||
-DOLLAMA_RUNNER_DIR=./
|
||||
cmake --build $BUILD_DIR --target ggml-cpu -j
|
||||
cmake --build $BUILD_DIR --target mlx mlxc -j
|
||||
cmake --install $BUILD_DIR --component CPU
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
-DOLLAMA_RUNNER_DIR=./ \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
|
||||
-DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX
|
||||
cmake --build build/darwin-$ARCH --target ggml-cpu -j
|
||||
cmake --install build/darwin-$ARCH --component CPU
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -61,10 +82,12 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/imagegen
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-amd64/lib/ollama/*; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -131,17 +154,23 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
cp dist/darwin-amd64/lib/ollama/*.so dist/darwin-amd64/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
else
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
@@ -149,7 +178,7 @@ _build_macapp() {
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
|
||||
@@ -12,6 +12,17 @@ set -eu
|
||||
|
||||
. $(dirname $0)/env.sh
|
||||
|
||||
# Check for required tools
|
||||
if ! command -v zstd >/dev/null 2>&1; then
|
||||
echo "ERROR: zstd is required but not installed." >&2
|
||||
echo "Please install zstd:" >&2
|
||||
echo " - macOS: brew install zstd" >&2
|
||||
echo " - Debian/Ubuntu: sudo apt-get install zstd" >&2
|
||||
echo " - RHEL/CentOS/Fedora: sudo dnf install zstd" >&2
|
||||
echo " - Arch: sudo pacman -S zstd" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p dist
|
||||
|
||||
docker buildx build \
|
||||
@@ -37,19 +48,68 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
|
||||
.
|
||||
fi
|
||||
|
||||
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
|
||||
deduplicate_cuda_libs() {
|
||||
local base_dir="$1"
|
||||
echo "Deduplicating CUDA libraries in ${base_dir}..."
|
||||
|
||||
# Find all mlx_cuda_* directories
|
||||
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
|
||||
[ -d "${mlx_dir}" ] || continue
|
||||
|
||||
# Extract CUDA version (e.g., v12, v13)
|
||||
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
|
||||
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
|
||||
|
||||
# Skip if corresponding cuda_* directory doesn't exist
|
||||
[ -d "${cuda_dir}" ] || continue
|
||||
|
||||
echo " Checking ${mlx_dir} against ${cuda_dir}..."
|
||||
|
||||
# Find all .so* files in mlx directory
|
||||
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
|
||||
filename=$(basename "${mlx_file}")
|
||||
cuda_file="${cuda_dir}/${filename}"
|
||||
|
||||
# Skip if file doesn't exist in cuda directory
|
||||
[ -f "${cuda_file}" ] || continue
|
||||
|
||||
# Compare checksums
|
||||
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
|
||||
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
|
||||
|
||||
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
|
||||
echo " Deduplicating ${filename}"
|
||||
# Calculate relative path from mlx_dir to cuda_dir
|
||||
rel_path="../cuda_${cuda_version}/${filename}"
|
||||
rm -f "${mlx_file}"
|
||||
ln -s "${rel_path}" "${mlx_file}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
# Run deduplication for each platform output directory
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist/linux_amd64"
|
||||
deduplicate_cuda_libs "./dist/linux_arm64"
|
||||
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
|
||||
deduplicate_cuda_libs "./dist"
|
||||
fi
|
||||
|
||||
# buildx behavior changes for single vs. multiplatform
|
||||
echo "Compressing linux tar bundles..."
|
||||
if echo $PLATFORM | grep "," > /dev/null ; then
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
tar c -C ./dist/linux_amd64 --exclude rocm . | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
|
||||
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack5.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | zstd --ultra -22 -T0 >./dist/ollama-linux-arm64-jetpack6.tar.zst
|
||||
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
|
||||
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
|
||||
tar c -C ./dist/ --exclude rocm bin lib | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64.tar.zst
|
||||
tar c -C ./dist/ ./lib/ollama/rocm | zstd --ultra -22 -T0 >./dist/ollama-linux-amd64-rocm.tar.zst
|
||||
fi
|
||||
|
||||
@@ -66,6 +66,36 @@ if [ -n "$NEEDS" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Function to download and extract with fallback from zst to tgz
|
||||
download_and_extract() {
|
||||
local url_base="$1"
|
||||
local dest_dir="$2"
|
||||
local filename="$3"
|
||||
|
||||
# Check if .tar.zst is available
|
||||
if curl --fail --silent --head --location "${url_base}/${filename}.tar.zst${VER_PARAM}" >/dev/null 2>&1; then
|
||||
# zst file exists - check if we have zstd tool
|
||||
if ! available zstd; then
|
||||
error "This version requires zstd for extraction. Please install zstd and try again:
|
||||
- Debian/Ubuntu: sudo apt-get install zstd
|
||||
- RHEL/CentOS/Fedora: sudo dnf install zstd
|
||||
- Arch: sudo pacman -S zstd"
|
||||
fi
|
||||
|
||||
status "Downloading ${filename}.tar.zst"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tar.zst${VER_PARAM}" | \
|
||||
zstd -d | $SUDO tar -xf - -C "${dest_dir}"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Fall back to .tgz for older versions
|
||||
status "Downloading ${filename}.tgz"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"${url_base}/${filename}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "${dest_dir}"
|
||||
}
|
||||
|
||||
for BINDIR in /usr/local/bin /usr/bin /bin; do
|
||||
echo $PATH | grep -q $BINDIR && break || continue
|
||||
done
|
||||
@@ -78,10 +108,7 @@ fi
|
||||
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
|
||||
status "Downloading Linux ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}"
|
||||
|
||||
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
||||
status "Making ollama accessible in the PATH in $BINDIR"
|
||||
@@ -91,15 +118,9 @@ fi
|
||||
# Check for NVIDIA JetPack systems with additional downloads
|
||||
if [ -f /etc/nv_tegra_release ] ; then
|
||||
if grep R36 /etc/nv_tegra_release > /dev/null ; then
|
||||
status "Downloading JetPack 6 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack6.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack6"
|
||||
elif grep R35 /etc/nv_tegra_release > /dev/null ; then
|
||||
status "Downloading JetPack 5 components"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-jetpack5.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-jetpack5"
|
||||
else
|
||||
warning "Unsupported JetPack version detected. GPU may not be supported"
|
||||
fi
|
||||
@@ -222,10 +243,7 @@ if ! check_gpu lspci nvidia && ! check_gpu lshw nvidia && ! check_gpu lspci amdg
|
||||
fi
|
||||
|
||||
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
|
||||
status "Downloading Linux ROCm ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}-rocm.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
download_and_extract "https://ollama.com/download" "$OLLAMA_INSTALL_DIR" "ollama-linux-${ARCH}-rocm"
|
||||
|
||||
install_success
|
||||
status "AMD GPU ready."
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
ofs "github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
@@ -454,7 +455,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
|
||||
for _, l := range baseLayers {
|
||||
if l.GGML != nil {
|
||||
return l.KV(), nil
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -41,7 +42,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
base := map[string]any{"general.architecture": "test"}
|
||||
var base convert.KV = map[string]any{"general.architecture": "test"}
|
||||
maps.Copy(base, kv)
|
||||
|
||||
if err := ggml.WriteGGUF(f, base, ti); err != nil {
|
||||
|
||||
@@ -381,6 +381,28 @@ func (t templateTools) String() string {
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateArgs is a map type with JSON string output for templates.
|
||||
type templateArgs map[string]any
|
||||
|
||||
func (t templateArgs) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateProperties is a map type with JSON string output for templates.
|
||||
type templateProperties map[string]api.ToolProperty
|
||||
|
||||
func (t templateProperties) String() string {
|
||||
if t == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateTool is a template-compatible representation of api.Tool
|
||||
// with Properties as a regular map for template ranging.
|
||||
type templateTool struct {
|
||||
@@ -396,11 +418,11 @@ type templateToolFunction struct {
|
||||
}
|
||||
|
||||
type templateToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties templateProperties `json:"properties"`
|
||||
}
|
||||
|
||||
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||
@@ -413,7 +435,7 @@ type templateToolCall struct {
|
||||
type templateToolCallFunction struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments map[string]any
|
||||
Arguments templateArgs
|
||||
}
|
||||
|
||||
// templateMessage is a template-compatible representation of api.Message
|
||||
@@ -446,7 +468,7 @@ func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||
Defs: tool.Function.Parameters.Defs,
|
||||
Items: tool.Function.Parameters.Items,
|
||||
Required: tool.Function.Parameters.Required,
|
||||
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||
Properties: templateProperties(tool.Function.Parameters.Properties.ToMap()),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -468,7 +490,7 @@ func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||
Function: templateToolCallFunction{
|
||||
Index: tc.Function.Index,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
Arguments: templateArgs(tc.Function.Arguments.ToMap()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -613,3 +613,159 @@ func TestCollate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Arguments }} outputs valid JSON, not map[key:value]
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{ .Function.Arguments }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("location", "Tokyo")
|
||||
args.Set("unit", "celsius")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:Tokyo unit:celsius]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Arguments output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Arguments not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesJSON(t *testing.T) {
|
||||
// Test that {{ .Function.Parameters.Properties }} outputs valid JSON
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{ .Function.Parameters.Properties }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "City name"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
// Should be valid JSON, not "map[location:{...}]"
|
||||
if strings.HasPrefix(got, "map[") {
|
||||
t.Errorf("Properties output as Go map format: %s", got)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(got), &parsed); err != nil {
|
||||
t.Errorf("Properties not valid JSON: %s, error: %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateArgumentsRange(t *testing.T) {
|
||||
// Test that we can range over Arguments in templates
|
||||
tmpl := `{{- range .Messages }}{{- range .ToolCalls }}{{- range $k, $v := .Function.Arguments }}{{ $k }}={{ $v }};{{- end }}{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("city", "Tokyo")
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: args,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "city=Tokyo;" {
|
||||
t.Errorf("Range over Arguments failed, got: %s, want: city=Tokyo;", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplatePropertiesRange(t *testing.T) {
|
||||
// Test that we can range over Properties in templates
|
||||
// Note: template must reference .Messages to trigger the modern code path that converts Tools
|
||||
tmpl := `{{- range .Messages }}{{- end }}{{- range .Tools }}{{- range $name, $prop := .Function.Parameters.Properties }}{{ $name }}:{{ $prop.Type }};{{- end }}{{- end }}`
|
||||
|
||||
template, err := Parse(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = template.Execute(&buf, Values{
|
||||
Messages: []api.Message{{Role: "user", Content: "test"}},
|
||||
Tools: api.Tools{{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := buf.String()
|
||||
if got != "location:string;" {
|
||||
t.Errorf("Range over Properties failed, got: %s, want: location:string;", got)
|
||||
}
|
||||
}
|
||||
|
||||
24
x/README.md
Normal file
24
x/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
|
||||
|
||||
```
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install --component MLX
|
||||
go build -tags mlx .
|
||||
```
|
||||
|
||||
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
|
||||
|
||||
## Image Generation
|
||||
|
||||
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
|
||||
|
||||
```
|
||||
go build -o imagegen ./x/imagegen/cmd/engine
|
||||
```
|
||||
@@ -37,6 +37,25 @@ var optionLabels = []string{
|
||||
"3. Deny",
|
||||
}
|
||||
|
||||
// toolDisplayNames maps internal tool names to human-readable display names.
|
||||
var toolDisplayNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"web_search": "Web Search",
|
||||
}
|
||||
|
||||
// ToolDisplayName returns the human-readable display name for a tool.
|
||||
func ToolDisplayName(toolName string) string {
|
||||
if displayName, ok := toolDisplayNames[toolName]; ok {
|
||||
return displayName
|
||||
}
|
||||
// Default: capitalize first letter and replace underscores with spaces
|
||||
name := strings.ReplaceAll(toolName, "_", " ")
|
||||
if len(name) > 0 {
|
||||
return strings.ToUpper(name[:1]) + name[1:]
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// autoAllowCommands are commands that are always allowed without prompting.
|
||||
// These are zero-risk, read-only commands.
|
||||
var autoAllowCommands = map[string]bool{
|
||||
@@ -509,11 +528,12 @@ func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any)
|
||||
// formatToolDisplay creates the display string for a tool call.
|
||||
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
var sb strings.Builder
|
||||
displayName := ToolDisplayName(toolName)
|
||||
|
||||
// For bash, show command directly
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||
return sb.String()
|
||||
}
|
||||
@@ -522,7 +542,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
// For web search, show query and internet notice
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s\n", query))
|
||||
sb.WriteString("Uses internet via ollama.com")
|
||||
return sb.String()
|
||||
@@ -530,7 +550,7 @@ func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
|
||||
if len(args) > 0 {
|
||||
sb.WriteString("\nArguments: ")
|
||||
first := true
|
||||
@@ -724,7 +744,7 @@ func wrapText(text string, maxWidth int) []string {
|
||||
|
||||
// getHintLines returns the hint text wrapped to terminal width
|
||||
func getHintLines(state *selectorState) []string {
|
||||
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||
hint := "up/down select, enter confirm, 1-3 quick select, ctrl+c cancel"
|
||||
if state.termWidth >= len(hint)+1 {
|
||||
return []string{hint}
|
||||
}
|
||||
@@ -734,86 +754,60 @@ func getHintLines(state *selectorState) []string {
|
||||
|
||||
// calculateTotalLines calculates how many lines the selector will use
|
||||
func calculateTotalLines(state *selectorState) int {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
hintLines := getHintLines(state)
|
||||
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||
// warning line (if applicable) + tool lines + blank line + options + blank line + hint lines
|
||||
warningLines := 0
|
||||
if state.isWarning {
|
||||
warningLines = 1
|
||||
warningLines = 2 // warning line + blank line after
|
||||
}
|
||||
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
return warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
}
|
||||
|
||||
// renderSelectorBox renders the complete selector box
|
||||
// renderSelectorBox renders the selector (minimal, no box)
|
||||
func renderSelectorBox(state *selectorState) {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
toolLines := strings.Split(state.toolDisplay, "\n")
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
// Draw warning line if needed
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m command targets paths outside project\033[K\r\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n") // blank line after warning
|
||||
}
|
||||
|
||||
// Draw box top
|
||||
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw warning line if needed (inside the box)
|
||||
if state.isWarning {
|
||||
warning := "!! OUTSIDE PROJECT !!"
|
||||
padding := (state.innerWidth - len(warning)) / 2
|
||||
if padding < 0 {
|
||||
padding = 0
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||
}
|
||||
|
||||
// Draw tool info
|
||||
// Draw tool info (plain white)
|
||||
for _, line := range toolLines {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||
fmt.Fprintf(os.Stderr, "%s\033[K\r\n", line)
|
||||
}
|
||||
|
||||
// Draw separator
|
||||
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
// Blank line separator
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
|
||||
// Draw options with numbers (Deny option includes reason input)
|
||||
// Draw options
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option - show with reason input beside it
|
||||
if i == 2 { // Deny option with input
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Draw box bottom
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
// Blank line before hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
|
||||
// Draw hint (may be multiple lines)
|
||||
// Draw hint (dark grey)
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
// Last line - no newline
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
@@ -825,50 +819,33 @@ func renderSelectorBox(state *selectorState) {
|
||||
func updateSelectorOptions(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the first option line
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||
// (hint lines - 1) + 1 (blank line) + numOptions
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options (Deny option includes reason input)
|
||||
// Redraw options
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m\033[K\r\n", label)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m\033[K\r\n", label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -882,36 +859,23 @@ func updateSelectorOptions(state *selectorState) {
|
||||
func updateReasonInput(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the Deny line (3rd option, index 2)
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||
// (hint lines - 1) + 1 (blank line) + 1 (Deny is last option)
|
||||
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw Deny line with reason
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if state.selected == 2 {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[1m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
fmt.Fprintf(os.Stderr, " \033[37m%s\033[0m%s\033[K\r\n", denyLabel, inputDisplay)
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
// Blank line + hint
|
||||
fmt.Fprintf(os.Stderr, "\033[K\r\n")
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
@@ -935,11 +899,10 @@ func clearSelectorBox(state *selectorState) {
|
||||
// fallbackApproval handles approval when terminal control isn't available.
|
||||
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "Choice: ")
|
||||
fmt.Fprint(os.Stderr, "choice: ")
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
@@ -982,19 +945,16 @@ func (a *ApprovalManager) AllowedTools() []string {
|
||||
|
||||
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||
var status string
|
||||
var icon string
|
||||
var label string
|
||||
displayName := ToolDisplayName(toolName)
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
status = "Approved"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
label = "approved"
|
||||
case ApprovalAlways:
|
||||
status = "Always allowed"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
label = "always allowed"
|
||||
case ApprovalDeny:
|
||||
status = "Denied"
|
||||
icon = "\033[31m✗\033[0m"
|
||||
label = "denied"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
@@ -1004,7 +964,7 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(cmd) > 40 {
|
||||
cmd = cmd[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1014,11 +974,11 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
|
||||
if len(query) > 40 {
|
||||
query = query[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, query)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
|
||||
}
|
||||
|
||||
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||
@@ -1049,15 +1009,14 @@ func PromptYesNo(question string) (bool, error) {
|
||||
renderYesNo := func() {
|
||||
// Move to start of line and clear
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
fmt.Fprintf(os.Stderr, "\033[36m%s\033[0m ", question)
|
||||
fmt.Fprintf(os.Stderr, "%s ", question)
|
||||
for i, opt := range options {
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, "\033[1;32m[%s]\033[0m ", opt)
|
||||
fmt.Fprintf(os.Stderr, "\033[1m%s\033[0m ", opt)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s \033[0m ", opt)
|
||||
fmt.Fprintf(os.Stderr, "\033[37m%s\033[0m ", opt)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[90m(←/→ or y/n, Enter to confirm)\033[0m")
|
||||
}
|
||||
|
||||
renderYesNo()
|
||||
@@ -1104,108 +1063,3 @@ func PromptYesNo(question string) (bool, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloudModelOption represents a suggested cloud model for the selection prompt.
|
||||
type CloudModelOption struct {
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// PromptModelChoice displays a model selection prompt with multiple options.
|
||||
// Returns the selected model name, or empty string if user declined or cancelled.
|
||||
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer term.Restore(fd, oldState)
|
||||
|
||||
// Build options: models + "No thanks, continue"
|
||||
optionCount := len(models) + 1
|
||||
selected := 0
|
||||
|
||||
// Total lines: question + models + "no thanks" + hint = optionCount + 2
|
||||
totalLines := optionCount + 2
|
||||
|
||||
// Hide cursor
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h")
|
||||
|
||||
firstRender := true
|
||||
|
||||
render := func() {
|
||||
if !firstRender {
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
|
||||
}
|
||||
firstRender = false
|
||||
|
||||
// \r\n needed in raw mode for proper line breaks
|
||||
fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)
|
||||
|
||||
for i, model := range models {
|
||||
fmt.Fprintf(os.Stderr, "\033[K")
|
||||
if i == selected {
|
||||
fmt.Fprintf(os.Stderr, " \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[K")
|
||||
if selected == len(models) {
|
||||
fmt.Fprintf(os.Stderr, " \033[1;32m> No thanks, continue\033[0m\r\n")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, " \033[90mNo thanks, continue\033[0m\r\n")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
|
||||
}
|
||||
|
||||
render()
|
||||
|
||||
buf := make([]byte, 3)
|
||||
for {
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if n == 1 {
|
||||
switch buf[0] {
|
||||
case 'j', 'J':
|
||||
if selected < optionCount-1 {
|
||||
selected++
|
||||
}
|
||||
render()
|
||||
case 'k', 'K':
|
||||
if selected > 0 {
|
||||
selected--
|
||||
}
|
||||
render()
|
||||
case '\r', '\n':
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if selected < len(models) {
|
||||
return models[selected].Name, nil
|
||||
}
|
||||
return "", nil
|
||||
case 3: // Ctrl+C
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
return "", nil
|
||||
}
|
||||
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
|
||||
switch buf[2] {
|
||||
case 'A': // Up
|
||||
if selected > 0 {
|
||||
selected--
|
||||
}
|
||||
render()
|
||||
case 'B': // Down
|
||||
if selected < optionCount-1 {
|
||||
selected++
|
||||
}
|
||||
render()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCloudModelOptionStruct(t *testing.T) {
|
||||
// Test that the struct is defined correctly
|
||||
models := []CloudModelOption{
|
||||
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
|
||||
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
|
||||
}
|
||||
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %d", len(models))
|
||||
}
|
||||
|
||||
if models[0].Name != "glm-4.7:cloud" {
|
||||
t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
|
||||
}
|
||||
|
||||
if models[1].Description != "Qwen3 Coder 480B" {
|
||||
t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
|
||||
}
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCloudModelSwitchRequest(t *testing.T) {
|
||||
// Test the error type
|
||||
req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}
|
||||
|
||||
// Test Error() method
|
||||
errMsg := req.Error()
|
||||
expected := "switch to model: glm-4.7:cloud"
|
||||
if errMsg != expected {
|
||||
t.Errorf("expected %q, got %q", expected, errMsg)
|
||||
}
|
||||
|
||||
// Test errors.As
|
||||
var err error = req
|
||||
var switchReq *CloudModelSwitchRequest
|
||||
if !errors.As(err, &switchReq) {
|
||||
t.Error("errors.As should return true for CloudModelSwitchRequest")
|
||||
}
|
||||
|
||||
if switchReq.Model != "glm-4.7:cloud" {
|
||||
t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuggestedCloudModels(t *testing.T) {
|
||||
// Verify the suggested models are defined
|
||||
if len(suggestedCloudModels) == 0 {
|
||||
t.Error("suggestedCloudModels should not be empty")
|
||||
}
|
||||
|
||||
// Check first model
|
||||
if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
|
||||
t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
|
||||
}
|
||||
}
|
||||
295
x/cmd/run.go
295
x/cmd/run.go
@@ -37,22 +37,6 @@ const (
|
||||
charsPerToken = 4
|
||||
)
|
||||
|
||||
// suggestedCloudModels are the models suggested to users after signing in.
|
||||
// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
|
||||
var suggestedCloudModels = []agent.CloudModelOption{
|
||||
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
|
||||
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
|
||||
}
|
||||
|
||||
// CloudModelSwitchRequest signals that the user wants to switch to a different model.
|
||||
type CloudModelSwitchRequest struct {
|
||||
Model string
|
||||
}
|
||||
|
||||
func (c *CloudModelSwitchRequest) Error() string {
|
||||
return fmt.Sprintf("switch to model: %s", c.Model)
|
||||
}
|
||||
|
||||
// isLocalModel checks if the model is running locally (not a cloud model).
|
||||
// TODO: Improve local/cloud model identification - could check model metadata
|
||||
func isLocalModel(modelName string) bool {
|
||||
@@ -107,8 +91,8 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.SigninURL != "" {
|
||||
fmt.Fprintf(os.Stderr, "\n To sign in, navigate to:\n")
|
||||
fmt.Fprintf(os.Stderr, " \033[36m%s\033[0m\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mWaiting for sign in to complete...\033[0m")
|
||||
fmt.Fprintf(os.Stderr, " %s\n\n", aErr.SigninURL)
|
||||
fmt.Fprintf(os.Stderr, " \033[90mwaiting for sign in to complete...\033[0m")
|
||||
|
||||
// Poll until auth succeeds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
@@ -122,7 +106,7 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
case <-ticker.C:
|
||||
user, whoamiErr := client.Whoami(ctx)
|
||||
if whoamiErr == nil && user != nil && user.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K \033[32mSigned in as %s\033[0m\n", user.Name)
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K \033[1msigned in:\033[0m %s\n", user.Name)
|
||||
return nil
|
||||
}
|
||||
// Still waiting, show dot
|
||||
@@ -135,21 +119,6 @@ func waitForOllamaSignin(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
|
||||
// Returns the selected model name, or empty string if user declines.
|
||||
func promptCloudModelSuggestion() string {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
|
||||
if err != nil || selectedModel == "" {
|
||||
return ""
|
||||
}
|
||||
return selectedModel
|
||||
}
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
@@ -168,47 +137,6 @@ type RunOptions struct {
|
||||
|
||||
// YoloMode skips all tool approval prompts
|
||||
YoloMode bool
|
||||
|
||||
// LastToolOutput stores the full output of the last tool execution
|
||||
// for Ctrl+O expansion. Updated by Chat(), read by caller.
|
||||
LastToolOutput *string
|
||||
|
||||
// LastToolOutputTruncated stores the truncated version shown inline
|
||||
LastToolOutputTruncated *string
|
||||
|
||||
// ActiveModel points to the current model name - can be updated mid-turn
|
||||
// for model switching. If nil, opts.Model is used.
|
||||
ActiveModel *string
|
||||
}
|
||||
|
||||
// getActiveModel returns the current model name, checking ActiveModel pointer first.
|
||||
func getActiveModel(opts *RunOptions) string {
|
||||
if opts.ActiveModel != nil && *opts.ActiveModel != "" {
|
||||
return *opts.ActiveModel
|
||||
}
|
||||
return opts.Model
|
||||
}
|
||||
|
||||
// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
|
||||
func showModelConnection(ctx context.Context, modelName string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.RemoteHost != "" {
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
@@ -308,7 +236,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: getActiveModel(&opts),
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
@@ -332,20 +260,17 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for 401 Unauthorized - prompt user to sign in
|
||||
var authErr api.AuthorizationError
|
||||
if errors.As(err, &authErr) {
|
||||
p.StopAndClear()
|
||||
fmt.Fprintf(os.Stderr, "\033[33mAuthentication required to use this cloud model.\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m cloud model requires authentication\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
suggestedModel := promptCloudModelSuggestion()
|
||||
if suggestedModel != "" {
|
||||
return nil, &CloudModelSwitchRequest{Model: suggestedModel}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
|
||||
continue
|
||||
// Retry the chat request
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying...\033[0m\n")
|
||||
continue // Retry the loop
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
|
||||
@@ -358,11 +283,11 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
p.StopAndClear()
|
||||
|
||||
if consecutiveErrors >= 3 {
|
||||
fmt.Fprintf(os.Stderr, "\033[31m✗ Too many consecutive errors, giving up\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m too many consecutive errors, giving up\n")
|
||||
return nil, fmt.Errorf("too many consecutive server errors: %s", statusErr.ErrorMessage)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ Server error (attempt %d/3): %s\033[0m\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m server error (attempt %d/3): %s\n", consecutiveErrors, statusErr.ErrorMessage)
|
||||
|
||||
// Include both the model's response and the error so it can learn
|
||||
assistantContent := fullResponse.String()
|
||||
@@ -428,8 +353,8 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check if command is denied (dangerous pattern)
|
||||
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||
fmt.Fprintf(os.Stderr, "\033[1mblocked:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, " matches dangerous pattern: %s\n", pattern)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||
@@ -440,7 +365,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauto-allowed:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
}
|
||||
@@ -450,7 +375,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
// In yolo mode, skip all approval prompts
|
||||
if opts.YoloMode {
|
||||
if !skipApproval {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
}
|
||||
} else if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
@@ -480,23 +405,22 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
} else if !skipApproval {
|
||||
// Already allowed - show running indicator
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[1mrunning:\033[0m %s\n", formatToolShort(toolName, args))
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
// Check if web search needs authentication
|
||||
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
|
||||
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
|
||||
// Prompt user to sign in
|
||||
fmt.Fprintf(os.Stderr, "\033[1mauth required:\033[0m web search requires authentication\n")
|
||||
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
|
||||
if promptErr == nil && result {
|
||||
// Get signin URL and wait for auth completion
|
||||
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
|
||||
suggestedModel := promptCloudModelSuggestion()
|
||||
if suggestedModel != "" && opts.ActiveModel != nil {
|
||||
*opts.ActiveModel = suggestedModel
|
||||
showModelConnection(ctx, suggestedModel)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
|
||||
// Retry the web search
|
||||
fmt.Fprintf(os.Stderr, "\033[90mretrying web search...\033[0m\n")
|
||||
toolResult, err = toolRegistry.Execute(call)
|
||||
if err == nil {
|
||||
goto toolSuccess
|
||||
@@ -504,7 +428,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[1merror:\033[0m %v\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
@@ -515,27 +439,17 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
toolSuccess:
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
truncatedOutput := ""
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated, press Ctrl+O to expand)"
|
||||
output = output[:300] + "... (truncated)"
|
||||
}
|
||||
truncatedOutput = output
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
// Store full and truncated output for Ctrl+O toggle
|
||||
if opts.LastToolOutput != nil {
|
||||
*opts.LastToolOutput = toolResult
|
||||
}
|
||||
if opts.LastToolOutputTruncated != nil {
|
||||
*opts.LastToolOutputTruncated = truncatedOutput
|
||||
}
|
||||
|
||||
// Truncate output to prevent context overflow
|
||||
toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
|
||||
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
@@ -585,17 +499,18 @@ func truncateUTF8(s string, limit int) string {
|
||||
|
||||
// formatToolShort returns a short description of a tool call.
|
||||
func formatToolShort(toolName string, args map[string]any) string {
|
||||
displayName := agent.ToolDisplayName(toolName)
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(cmd, 50))
|
||||
}
|
||||
}
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||
return fmt.Sprintf("%s: %s", displayName, truncateUTF8(query, 50))
|
||||
}
|
||||
}
|
||||
return toolName
|
||||
return displayName
|
||||
}
|
||||
|
||||
// Helper types and functions for display
|
||||
@@ -694,28 +609,25 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
return out
|
||||
}
|
||||
|
||||
// checkModelCapabilities checks if the model supports tools and thinking.
|
||||
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
|
||||
// checkModelCapabilities checks if the model supports tools.
|
||||
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityTools {
|
||||
supportsTools = true
|
||||
}
|
||||
if cap == model.CapabilityThinking {
|
||||
supportsThinking = true
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return supportsTools, supportsThinking, nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
@@ -735,29 +647,25 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
// Check if model supports tools and thinking
|
||||
supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
// Check if model supports tools
|
||||
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m could not check model capabilities: %v\n", err)
|
||||
supportsTools = false
|
||||
supportsThinking = false
|
||||
}
|
||||
|
||||
// Track if session is using thinking mode
|
||||
usingThinking := think != nil && supportsThinking
|
||||
|
||||
// Create tool registry only if model supports tools
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
if toolRegistry.Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
fmt.Fprintf(os.Stderr, "\033[90mtools available: %s\033[0m\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
}
|
||||
if yoloMode {
|
||||
fmt.Fprintf(os.Stderr, "\033[33m⚠ YOLO mode: All tool approvals will be skipped\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
fmt.Fprintf(os.Stderr, "\033[1mnote:\033[0m model does not support tools - running in chat-only mode\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
@@ -766,11 +674,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
// Track last tool output for Ctrl+O toggle
|
||||
var lastToolOutput string
|
||||
var lastToolOutputTruncated string
|
||||
var toolOutputExpanded bool
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
@@ -783,20 +686,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case errors.Is(err, readline.ErrExpandOutput):
|
||||
// Ctrl+O pressed - toggle between expanded and collapsed tool output
|
||||
if lastToolOutput == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mNo tool output to expand\033[0m\n")
|
||||
} else if toolOutputExpanded {
|
||||
// Currently expanded, show truncated
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutputTruncated, "\n", "\n "))
|
||||
toolOutputExpanded = false
|
||||
} else {
|
||||
// Currently collapsed, show full
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(lastToolOutput, "\n", "\n "))
|
||||
toolOutputExpanded = true
|
||||
}
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
@@ -833,44 +722,26 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
if sb.Len() > 0 {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
toolOutputExpanded = false
|
||||
|
||||
retryChat:
|
||||
for {
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
LastToolOutput: &lastToolOutput,
|
||||
LastToolOutputTruncated: &lastToolOutputTruncated,
|
||||
ActiveModel: &modelName,
|
||||
}
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
YoloMode: yoloMode,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
var switchReq *CloudModelSwitchRequest
|
||||
if errors.As(err, &switchReq) {
|
||||
newModel := switchReq.Model
|
||||
if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
|
||||
fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
|
||||
}
|
||||
continue retryChat
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if assistant != nil {
|
||||
messages = append(messages, *assistant)
|
||||
}
|
||||
break retryChat
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
messages = append(messages, *assistant)
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
@@ -878,52 +749,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
}
|
||||
|
||||
// switchToModel handles model switching with capability checks and UI updates.
|
||||
func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create client: %w", err)
|
||||
}
|
||||
|
||||
newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
|
||||
if capErr != nil {
|
||||
return fmt.Errorf("could not check model capabilities: %w", capErr)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
|
||||
if usingThinking && !newSupportsThinking {
|
||||
return fmt.Errorf("%s does not support thinking mode", newModel)
|
||||
}
|
||||
|
||||
// Show "Connecting to X on ollama.com" for cloud models
|
||||
info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
|
||||
if err == nil && info.RemoteHost != "" {
|
||||
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
|
||||
}
|
||||
}
|
||||
|
||||
*modelName = newModel
|
||||
*supportsTools = newSupportsTools
|
||||
*supportsThinking = newSupportsThinking
|
||||
|
||||
if *supportsTools {
|
||||
if *toolRegistry == nil {
|
||||
*toolRegistry = tools.DefaultRegistry()
|
||||
}
|
||||
if (*toolRegistry).Count() > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
|
||||
}
|
||||
} else {
|
||||
*toolRegistry = nil
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// showToolsStatus displays the current tools and approval status.
|
||||
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
|
||||
if !supportsTools || registry == nil {
|
||||
|
||||
38
x/imagegen/.gitignore
vendored
Normal file
38
x/imagegen/.gitignore
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
# Build directories
|
||||
build/
|
||||
dist/
|
||||
|
||||
# CMake
|
||||
CMakeCache.txt
|
||||
CMakeFiles/
|
||||
cmake_install.cmake
|
||||
Makefile
|
||||
*.cmake
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
*.dSYM/
|
||||
|
||||
# Go
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Python
|
||||
*.npy
|
||||
|
||||
/engine
|
||||
weights
|
||||
outputs
|
||||
|
||||
prompt.txt
|
||||
negative.txt
|
||||
61
x/imagegen/README.md
Normal file
61
x/imagegen/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# imagegen
|
||||
|
||||
This is a package that uses MLX to run image generation models, ahead of being integrated into Ollama's primary runner.
|
||||
in `CMakeLists.txt` and rebuild.
|
||||
|
||||
### 1. Download a Model
|
||||
|
||||
Download Llama 3.1 8B (or any compatible model) in safetensors format:
|
||||
|
||||
```bash
|
||||
mkdir -p ./weights
|
||||
|
||||
# Example using huggingface-cli
|
||||
hf download meta-llama/Llama-3.1-8B --local-dir ./weights/Llama-3.1-8B
|
||||
hf download openai/gpt-oss-20b --local-dir ./weights/gpt-oss-20b
|
||||
```
|
||||
|
||||
### 2. Run Inference
|
||||
|
||||
```bash
|
||||
# Build
|
||||
go build ./cmd/engine
|
||||
|
||||
# Text generation
|
||||
./engine -model ./weights/Llama-3.1-8B -prompt "Hello, world!" -max-tokens 250
|
||||
|
||||
# Qwen-Image 2512 (text-to-image)
|
||||
./engine -qwen-image -model ./weights/Qwen-Image-2512 -prompt "A mountain landscape at sunset" \
|
||||
-width 1024 -height 1024 -steps 20 -seed 42 -output landscape.png
|
||||
|
||||
# Qwen-Image Edit (experimental) - 8 steps for speed, but model recommends 50
|
||||
./engine -qwen-image-edit -model ./weights/Qwen-Image-Edit-2511 \
|
||||
-input-image input.png -prompt "Make it winter" -negative-prompt " " -cfg-scale 4.0 \
|
||||
-steps 8 -seed 42 -output edited.png
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
|
||||
MLX Python/C++ uses scope-based memory management - arrays are freed when they go out of scope. Go's garbage collector is non-deterministic, so we can't rely on finalizers to free GPU memory promptly.
|
||||
|
||||
Instead, arrays are automatically tracked and freed on `Eval()`:
|
||||
|
||||
```go
|
||||
// All arrays are automatically tracked when created
|
||||
x := mlx.Add(a, b)
|
||||
y := mlx.Matmul(x, w)
|
||||
|
||||
// Eval frees non-kept arrays, evaluates outputs (auto-kept)
|
||||
mlx.Eval(y)
|
||||
|
||||
// After copying to CPU, free the array
|
||||
data := y.Data()
|
||||
y.Free()
|
||||
```
|
||||
|
||||
Key points:
|
||||
|
||||
- All created arrays are automatically tracked
|
||||
- `mlx.Eval(outputs...)` frees non-kept arrays, evaluates outputs (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` marks arrays to survive multiple Eval cycles (for weights, caches)
|
||||
- Call `.Free()` when done with an array
|
||||
156
x/imagegen/cache/cache.go
vendored
Normal file
156
x/imagegen/cache/cache.go
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
type Cache interface {
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
Offset() int
|
||||
Len() int
|
||||
State() []*mlx.Array
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if needed
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
idx int
|
||||
}
|
||||
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if not yet at max
|
||||
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
|
||||
var cap int
|
||||
if c.keys != nil {
|
||||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
// Rotate when hitting max
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
164
x/imagegen/cache/step.go
vendored
Normal file
164
x/imagegen/cache/step.go
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// StepCache caches layer outputs across diffusion denoising steps.
|
||||
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between consecutive steps, so we can
|
||||
// cache their outputs and skip recomputation on non-refresh steps.
|
||||
//
|
||||
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
|
||||
// - Single-stream: use Get/Set for the single output per layer
|
||||
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
|
||||
//
|
||||
// Usage (single-stream):
|
||||
//
|
||||
// cache := NewStepCache(15) // cache first 15 layers
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// output = cache.Get(i) // reuse cached
|
||||
// } else {
|
||||
// output = layer.Forward(input)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, output)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// cache.Free() // cleanup when done
|
||||
//
|
||||
// Usage (dual-stream):
|
||||
//
|
||||
// cache := NewStepCache(15)
|
||||
// for step := 0; step < numSteps; step++ {
|
||||
// refresh := cache.ShouldRefresh(step, 3)
|
||||
// for i, layer := range layers {
|
||||
// if i < 15 && !refresh && cache.Get(i) != nil {
|
||||
// imgH, txtH = cache.Get(i), cache.Get2(i)
|
||||
// } else {
|
||||
// imgH, txtH = layer.Forward(imgH, txtH, ...)
|
||||
// if i < 15 && refresh {
|
||||
// cache.Set(i, imgH)
|
||||
// cache.Set2(i, txtH)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
type StepCache struct {
|
||||
layers []*mlx.Array // cached layer outputs (stream 1)
|
||||
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
|
||||
constant *mlx.Array // optional constant (e.g., text embeddings)
|
||||
}
|
||||
|
||||
// NewStepCache creates a cache for the given number of layers.
|
||||
func NewStepCache(numLayers int) *StepCache {
|
||||
return &StepCache{
|
||||
layers: make([]*mlx.Array, numLayers),
|
||||
layers2: make([]*mlx.Array, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldRefresh returns true if the cache should be refreshed at this step.
|
||||
// Refresh happens on step 0, interval, 2*interval, etc.
|
||||
func (c *StepCache) ShouldRefresh(step, interval int) bool {
|
||||
return step%interval == 0
|
||||
}
|
||||
|
||||
// Get returns the cached output for a layer, or nil if not cached.
|
||||
func (c *StepCache) Get(layer int) *mlx.Array {
|
||||
if layer < len(c.layers) {
|
||||
return c.layers[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set stores a layer output (stream 1), freeing any previous value.
|
||||
func (c *StepCache) Set(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers) {
|
||||
if c.layers[layer] != nil {
|
||||
c.layers[layer].Free()
|
||||
}
|
||||
c.layers[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Get2(layer int) *mlx.Array {
|
||||
if layer < len(c.layers2) {
|
||||
return c.layers2[layer]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set2 stores a layer output (stream 2), freeing any previous value.
|
||||
// Used for dual-stream architectures like Qwen-Image.
|
||||
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
|
||||
if layer < len(c.layers2) {
|
||||
if c.layers2[layer] != nil {
|
||||
c.layers2[layer].Free()
|
||||
}
|
||||
c.layers2[layer] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// GetConstant returns the cached constant value.
|
||||
func (c *StepCache) GetConstant() *mlx.Array {
|
||||
return c.constant
|
||||
}
|
||||
|
||||
// SetConstant stores a constant value, freeing any previous value.
|
||||
func (c *StepCache) SetConstant(arr *mlx.Array) {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
}
|
||||
c.constant = arr
|
||||
}
|
||||
|
||||
// Arrays returns all non-nil cached arrays (for pool.Keep).
|
||||
func (c *StepCache) Arrays() []*mlx.Array {
|
||||
var result []*mlx.Array
|
||||
if c.constant != nil {
|
||||
result = append(result, c.constant)
|
||||
}
|
||||
for _, arr := range c.layers {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
for _, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
result = append(result, arr)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Free releases all cached arrays. Call when generation completes.
|
||||
func (c *StepCache) Free() {
|
||||
if c.constant != nil {
|
||||
c.constant.Free()
|
||||
c.constant = nil
|
||||
}
|
||||
for i, arr := range c.layers {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers[i] = nil
|
||||
}
|
||||
}
|
||||
for i, arr := range c.layers2 {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
c.layers2[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NumLayers returns the number of layers this cache can store.
|
||||
func (c *StepCache) NumLayers() int {
|
||||
return len(c.layers)
|
||||
}
|
||||
359
x/imagegen/cmd/engine/generate.go
Normal file
359
x/imagegen/cmd/engine/generate.go
Normal file
@@ -0,0 +1,359 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
|
||||
// This handles cases where tokenizers output partial multi-byte sequences.
|
||||
type utf8Streamer struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
|
||||
func (s *utf8Streamer) Write(text string) string {
|
||||
s.buffer = append(s.buffer, text...)
|
||||
|
||||
// Find the last position that ends with a complete UTF-8 character
|
||||
validLen := 0
|
||||
for i := 0; i < len(s.buffer); {
|
||||
r, size := utf8.DecodeRune(s.buffer[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
// Invalid or incomplete UTF-8 sequence at this position
|
||||
// Check if it could be a valid start of a multi-byte sequence
|
||||
if len(s.buffer)-i < 4 {
|
||||
// Might be incomplete, keep it in buffer
|
||||
break
|
||||
}
|
||||
// Definitely invalid, skip this byte
|
||||
i++
|
||||
validLen = i
|
||||
} else {
|
||||
i += size
|
||||
validLen = i
|
||||
}
|
||||
}
|
||||
|
||||
if validLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := string(s.buffer[:validLen])
|
||||
s.buffer = s.buffer[validLen:]
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
|
||||
func (s *utf8Streamer) Flush() string {
|
||||
if len(s.buffer) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := string(s.buffer)
|
||||
s.buffer = nil
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
// ChatModel is an optional interface for models that support chat formatting
|
||||
type ChatModel interface {
|
||||
FormatPrompt(prompt string) string
|
||||
}
|
||||
|
||||
// MultimodalModel is for models that support image input
|
||||
type MultimodalModel interface {
|
||||
Model
|
||||
FormatPromptWithImage(prompt string) string
|
||||
ExpandImageTokens(tokens []int32) []int32
|
||||
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
ImageSize() int32 // Returns expected image size for preprocessing
|
||||
}
|
||||
|
||||
// ImageLoader loads and preprocesses an image for multimodal models
|
||||
// Returns nil if path is empty
|
||||
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
|
||||
|
||||
type input struct {
|
||||
Prompt string
|
||||
Image *mlx.Array // Optional preprocessed image for multimodal models
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
}
|
||||
|
||||
type output struct {
|
||||
Text string
|
||||
Done bool
|
||||
PrefillTokSec float64
|
||||
GenTokSec float64
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
type Decoder struct {
|
||||
model Model
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// For multimodal models with an image, we need to process all tokens together
|
||||
// in the first forward pass so the image embeddings can be inserted properly.
|
||||
// Skip chunking for multimodal prefill.
|
||||
isMultimodal := d.image != nil
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
// Skip chunking for multimodal - process everything in the final step
|
||||
if !isMultimodal {
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling (or all tokens for multimodal)
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
|
||||
var logits *mlx.Array
|
||||
// Use ForwardWithImage if we have an image and model supports it
|
||||
if d.image != nil {
|
||||
if mm, ok := d.model.(MultimodalModel); ok {
|
||||
logits = mm.ForwardWithImage(x, d.image, d.caches)
|
||||
d.image = nil // Only use image for first forward
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
mlx.EnableCompile()
|
||||
wiredLimit := in.WiredLimitGB
|
||||
if wiredLimit <= 0 {
|
||||
wiredLimit = 32 // default 32GB
|
||||
}
|
||||
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
|
||||
|
||||
temp := in.Temperature
|
||||
if temp < 0 {
|
||||
temp = 0.7
|
||||
}
|
||||
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
|
||||
prompt = mm.FormatPromptWithImage(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
|
||||
dec.SetImage(in.Image)
|
||||
} else if cm, ok := m.(ChatModel); ok {
|
||||
prompt = cm.FormatPrompt(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
} else {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
// Step() waits for prefill to complete and returns first token
|
||||
firstToken := dec.step()
|
||||
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
|
||||
|
||||
genStart := time.Now()
|
||||
maxTokens := max(in.MaxTokens, 100)
|
||||
var genTokens int
|
||||
|
||||
// UTF-8 streamer to handle partial multi-byte characters
|
||||
streamer := &utf8Streamer{}
|
||||
|
||||
// Handle first token
|
||||
genTokens++
|
||||
if tok.IsEOS(firstToken) {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
|
||||
return nil
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
token := dec.step()
|
||||
genTokens++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
break
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining buffered bytes
|
||||
if text := streamer.Flush(); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec,
|
||||
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
89
x/imagegen/cmd/engine/image.go
Normal file
89
x/imagegen/cmd/engine/image.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// saveImageArray saves an MLX array as a PNG image.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func saveImageArray(arr *mlx.Array, path string) error {
|
||||
img, err := arrayToImage(arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return savePNG(img, path)
|
||||
}
|
||||
|
||||
func savePNG(img *image.RGBA, path string) error {
|
||||
if filepath.Ext(path) != ".png" {
|
||||
path = path + ".png"
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, img)
|
||||
}
|
||||
|
||||
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
shape := arr.Shape()
|
||||
if len(shape) != 4 {
|
||||
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
arr.Free()
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
H := int(imgShape[0])
|
||||
W := int(imgShape[1])
|
||||
C := int(imgShape[2])
|
||||
|
||||
if C != 3 {
|
||||
img.Free()
|
||||
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
||||
}
|
||||
|
||||
// Copy to CPU and free GPU memory
|
||||
data := img.Data()
|
||||
img.Free()
|
||||
|
||||
// Write directly to Pix slice (faster than SetRGBA)
|
||||
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
||||
pix := goImg.Pix
|
||||
for y := 0; y < H; y++ {
|
||||
for x := 0; x < W; x++ {
|
||||
srcIdx := (y*W + x) * C
|
||||
dstIdx := (y*W + x) * 4
|
||||
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
||||
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
||||
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
||||
pix[dstIdx+3] = 255
|
||||
}
|
||||
}
|
||||
|
||||
return goImg, nil
|
||||
}
|
||||
|
||||
func clampF(v, min, max float32) float32 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
286
x/imagegen/cmd/engine/main.go
Normal file
286
x/imagegen/cmd/engine/main.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// stringSlice is a flag type that accumulates multiple values
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
modelPath := flag.String("model", "", "Model directory")
|
||||
prompt := flag.String("prompt", "Hello", "Prompt")
|
||||
|
||||
// Text generation params
|
||||
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
||||
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
// Utility flags
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *modelPath == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle legacy mode flags that aren't unified yet
|
||||
switch {
|
||||
case *zimageFlag:
|
||||
m := &zimage.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
LayerCache: *layerCache,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImageEdit:
|
||||
if len(inputImages) == 0 {
|
||||
log.Fatal("qwen-image-edit requires at least one -input-image")
|
||||
}
|
||||
|
||||
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// For image editing, use 0 for dimensions to auto-detect from input image
|
||||
// unless explicitly overridden from defaults
|
||||
editWidth := int32(0)
|
||||
editHeight := int32(0)
|
||||
if *width != 1024 {
|
||||
editWidth = int32(*width)
|
||||
}
|
||||
if *height != 1024 {
|
||||
editHeight = int32(*height)
|
||||
}
|
||||
|
||||
cfg := &qwen_image_edit.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: editWidth,
|
||||
Height: editHeight,
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
}
|
||||
|
||||
var img *mlx.Array
|
||||
img, err = m.EditFromConfig(inputImages, cfg)
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
// llm path
|
||||
m, err := load(*modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatal("model does not support image input")
|
||||
}
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
MaxTokens: *maxTokens,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
}
|
||||
if out.Done {
|
||||
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listModelTensors(modelPath string) error {
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range weights.ListTensors() {
|
||||
info, _ := weights.GetTensorInfo(name)
|
||||
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModel builds and evaluates a model using the common load pattern.
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
||||
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
||||
func loadModel[T Model](build func() T, cleanup func()) T {
|
||||
m := build()
|
||||
weights := mlx.Collect(m)
|
||||
cleanup()
|
||||
mlx.Eval(weights...)
|
||||
return m
|
||||
}
|
||||
|
||||
func load(modelPath string) (Model, error) {
|
||||
kind, err := detectModelKind(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect model kind: %w", err)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "gpt_oss":
|
||||
return gpt_oss.Load(modelPath)
|
||||
case "gemma3":
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
default:
|
||||
return llama.Load(modelPath)
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelKind(modelPath string) (string, error) {
|
||||
indexPath := filepath.Join(modelPath, "model_index.json")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
data, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "zimage", nil
|
||||
}
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err == nil {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(modelPath, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
49
x/imagegen/cmd/engine/sample.go
Normal file
49
x/imagegen/cmd/engine/sample.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build mlx
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// sampleTopK samples from top-k logits using global random state
|
||||
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
|
||||
neg := mlx.Neg(scaledLogits)
|
||||
indices := mlx.Argpartition(neg, k-1, -1)
|
||||
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
|
||||
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
|
||||
sampled := mlx.RandomCategorical(values, -1, 1)
|
||||
return mlx.Take(topKIdx, sampled, -1)
|
||||
}
|
||||
|
||||
// sampleTopP samples using nucleus sampling with global random state
|
||||
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
|
||||
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
|
||||
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
|
||||
probs := mlx.Softmax(sortedLogits, -1)
|
||||
cumProbs := mlx.Cumsum(probs, -1)
|
||||
mask := mlx.LessScalar(cumProbs, p)
|
||||
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
|
||||
masked := mlx.Where(mask, sortedLogits, negInf)
|
||||
sampled := mlx.RandomCategorical(masked, -1, 1)
|
||||
return mlx.Take(sorted, sampled, -1)
|
||||
}
|
||||
|
||||
// sample samples from logits at the last position
|
||||
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocab)
|
||||
|
||||
if temp == 0 {
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
if topK > 0 && topK < int(vocab) {
|
||||
return sampleTopK(scaled, topK)
|
||||
}
|
||||
if topP > 0 && topP < 1.0 {
|
||||
return sampleTopP(scaled, topP, vocab)
|
||||
}
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
46
x/imagegen/mlx/README.md
Normal file
46
x/imagegen/mlx/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# MLX Memory Management
|
||||
|
||||
| This package will get consolidated with `x/ml/backend/mlx` in the future.
|
||||
|
||||
## Automatic Tracking
|
||||
|
||||
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
|
||||
|
||||
### API
|
||||
|
||||
```go
|
||||
result := mlx.Matmul(x, w) // arrays automatically tracked
|
||||
mlx.Eval(result) // free non-kept, eval result (auto-kept)
|
||||
```
|
||||
|
||||
### Key Functions
|
||||
|
||||
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
|
||||
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
|
||||
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
|
||||
- `array.Free()` - mark array for cleanup on next Eval
|
||||
|
||||
### Loop Pattern
|
||||
|
||||
```go
|
||||
for step := 0; step < maxTokens; step++ {
|
||||
logits := model.Forward(token, caches)
|
||||
oldToken := token
|
||||
token = sample(logits)
|
||||
|
||||
// Keep cache state across iterations
|
||||
for _, c := range caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
|
||||
oldToken.Free() // mark for cleanup
|
||||
mlx.AsyncEval(token) // frees old, evals new
|
||||
}
|
||||
```
|
||||
|
||||
### Notes
|
||||
|
||||
- `Eval()` and `AsyncEval()` auto-keep their outputs
|
||||
- `Free()` marks for cleanup - actual free happens during next Eval
|
||||
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
|
||||
- Arrays created inside compiled closures are managed by MLX, not tracked
|
||||
173
x/imagegen/mlx/compile.go
Normal file
173
x/imagegen/mlx/compile.go
Normal file
@@ -0,0 +1,173 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
|
||||
// Destructor for payload (Go handle)
|
||||
extern void goClosureDestructor(void* payload);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// inClosureCallback is set to true during closure callback execution.
|
||||
var inClosureCallback bool
|
||||
var closureCallbackMu sync.Mutex
|
||||
|
||||
// InClosureCallback returns true if we're currently executing inside a closure callback.
|
||||
func InClosureCallback() bool {
|
||||
closureCallbackMu.Lock()
|
||||
defer closureCallbackMu.Unlock()
|
||||
return inClosureCallback
|
||||
}
|
||||
|
||||
// CompiledFunc is a compiled MLX function that can be called efficiently.
|
||||
// All intermediate arrays during execution stay inside MLX - only inputs
|
||||
// and outputs cross the Go boundary.
|
||||
type CompiledFunc struct {
|
||||
closure C.mlx_closure
|
||||
compiled C.mlx_closure
|
||||
}
|
||||
|
||||
// ClosureFunc is the signature for functions that can be compiled.
|
||||
// It takes a slice of input arrays and returns a slice of output arrays.
|
||||
type ClosureFunc func(inputs []*Array) []*Array
|
||||
|
||||
// Compile compiles a Go function into an optimized MLX closure.
|
||||
// The function is traced once during compilation, then subsequent calls
|
||||
// run the optimized graph without creating Go intermediate arrays.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
// a, b := inputs[0], inputs[1]
|
||||
// c := mlx.Add(a, b)
|
||||
// d := mlx.Mul(c, c)
|
||||
// return []*mlx.Array{d}
|
||||
// })
|
||||
// defer compiled.Free()
|
||||
//
|
||||
// result := compiled.Call(x, y)[0]
|
||||
func Compile(fn ClosureFunc) *CompiledFunc {
|
||||
return CompileShapeless(fn, false)
|
||||
}
|
||||
|
||||
// CompileShapeless compiles with optional shapeless mode.
|
||||
// If shapeless=true, the function works for any input shape after tracing.
|
||||
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
|
||||
// Create a cgo.Handle to prevent the Go function from being GC'd
|
||||
handle := cgo.NewHandle(fn)
|
||||
|
||||
// Create the closure from the Go callback
|
||||
closure := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.goClosureCallback),
|
||||
unsafe.Pointer(handle),
|
||||
(*[0]byte)(C.goClosureDestructor),
|
||||
)
|
||||
|
||||
// Compile the closure
|
||||
compiled := C.mlx_closure_new()
|
||||
C.mlx_compile(&compiled, closure, C.bool(shapeless))
|
||||
|
||||
return &CompiledFunc{
|
||||
closure: closure,
|
||||
compiled: compiled,
|
||||
}
|
||||
}
|
||||
|
||||
// Call invokes the compiled function with the given inputs.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
// Pack inputs into vector
|
||||
inputVec := C.mlx_vector_array_new()
|
||||
for _, arr := range inputs {
|
||||
C.mlx_vector_array_append_value(inputVec, arr.c)
|
||||
}
|
||||
|
||||
// Apply compiled closure
|
||||
outputVec := C.mlx_vector_array_new()
|
||||
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
|
||||
C.mlx_vector_array_free(inputVec)
|
||||
|
||||
// Unpack outputs
|
||||
numOutputs := int(C.mlx_vector_array_size(outputVec))
|
||||
outputs := make([]*Array, numOutputs)
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
|
||||
outputs[i] = newArray(arr)
|
||||
}
|
||||
C.mlx_vector_array_free(outputVec)
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
// CallEval invokes the compiled function and evaluates the results.
|
||||
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
|
||||
outputs := cf.Call(inputs...)
|
||||
Eval(outputs...)
|
||||
return outputs
|
||||
}
|
||||
|
||||
// Free releases the compiled function resources.
|
||||
func (cf *CompiledFunc) Free() {
|
||||
C.mlx_closure_free(cf.compiled)
|
||||
C.mlx_closure_free(cf.closure)
|
||||
}
|
||||
|
||||
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
|
||||
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
|
||||
func borrowArray(array C.mlx_array) *Array {
|
||||
return &Array{c: array}
|
||||
}
|
||||
|
||||
//export goClosureCallback
|
||||
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
||||
// Set flag to disable AddCleanup during callback
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = true
|
||||
closureCallbackMu.Unlock()
|
||||
defer func() {
|
||||
closureCallbackMu.Lock()
|
||||
inClosureCallback = false
|
||||
closureCallbackMu.Unlock()
|
||||
}()
|
||||
|
||||
// Recover the Go function from the handle
|
||||
handle := cgo.Handle(payload)
|
||||
fn := handle.Value().(ClosureFunc)
|
||||
|
||||
// Convert input vector to Go slice - use borrowArray since MLX owns these
|
||||
numInputs := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, numInputs)
|
||||
for i := 0; i < numInputs; i++ {
|
||||
var arr C.mlx_array
|
||||
C.mlx_vector_array_get(&arr, input, C.size_t(i))
|
||||
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
|
||||
}
|
||||
|
||||
// Call the Go function
|
||||
outputs := fn(inputs)
|
||||
|
||||
// Build output vector
|
||||
*res = C.mlx_vector_array_new()
|
||||
for _, arr := range outputs {
|
||||
C.mlx_vector_array_append_value(*res, arr.c)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
//export goClosureDestructor
|
||||
func goClosureDestructor(payload unsafe.Pointer) {
|
||||
handle := cgo.Handle(payload)
|
||||
handle.Delete()
|
||||
}
|
||||
2238
x/imagegen/mlx/mlx.go
Normal file
2238
x/imagegen/mlx/mlx.go
Normal file
File diff suppressed because it is too large
Load Diff
1145
x/imagegen/mlx/mlx_test.go
Normal file
1145
x/imagegen/mlx/mlx_test.go
Normal file
File diff suppressed because it is too large
Load Diff
614
x/imagegen/models/gemma3/gemma3.go
Normal file
614
x/imagegen/models/gemma3/gemma3.go
Normal file
@@ -0,0 +1,614 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds configuration for the text model
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
|
||||
// Computed fields
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
// TextModel is the Gemma 3 text-only model
|
||||
type TextModel struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*DecoderLayer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
|
||||
NormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block
|
||||
type DecoderLayer struct {
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
Attention *Attention
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
|
||||
MLP *MLP
|
||||
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
InputNormScaled *mlx.Array `weight:"-"`
|
||||
PostAttnNormScaled *mlx.Array `weight:"-"`
|
||||
PreFFNormScaled *mlx.Array `weight:"-"`
|
||||
PostFFNormScaled *mlx.Array `weight:"-"`
|
||||
|
||||
// Whether this layer uses sliding window attention
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
QNormScaled *mlx.Array `weight:"-"`
|
||||
KNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network with GELU activation
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// LoadText loads the text-only Gemma 3 model
|
||||
func LoadText(modelPath string) (*TextModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute scale
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
// Set defaults if not specified
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
TextConfig: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize layer metadata
|
||||
for i := range m.Layers {
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for output
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
|
||||
precomputeGemmaScaledWeights(m)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
|
||||
// This avoids creating temporary arrays on every forward pass
|
||||
func precomputeGemmaScaledWeights(m *TextModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
// Eval all the precomputed weights
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Eval(scaled...)
|
||||
}
|
||||
|
||||
// isLayerSliding determines if a layer uses sliding window attention
|
||||
// Pattern N means: layers 0 to N-1 sliding, N full, N+1 to 2N-1 sliding, 2N full, etc.
|
||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
||||
if pattern <= 0 {
|
||||
return false // No sliding window
|
||||
}
|
||||
// Layer is full attention if (layerIdx + 1) % pattern == 0
|
||||
return (layerIdx+1)%pattern != 0
|
||||
}
|
||||
|
||||
// Forward runs the text model forward pass
|
||||
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
// Get embeddings and scale by sqrt(hidden_size)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// Forward runs a decoder layer
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
// Pre-attention norm (use precomputed scaled weight)
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Attention
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
|
||||
// Post-attention norm and residual
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-FFN norm
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// MLP
|
||||
mlpOut := l.MLP.Forward(normed)
|
||||
|
||||
// Post-FFN norm and residual
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
// Forward runs attention with Q/K normalization
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K normalization after reshaping (use precomputed scaled weight)
|
||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// Apply RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
||||
// Repeat K/V for GQA if needed
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
|
||||
var compiledGeluApprox *mlx.CompiledFunc
|
||||
|
||||
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
|
||||
func getCompiledGeluApprox() *mlx.CompiledFunc {
|
||||
if compiledGeluApprox == nil {
|
||||
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApproxImpl(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGeluApprox
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU approximation (tanh variant)
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApproxImpl(x *mlx.Array) *mlx.Array {
|
||||
// Constants
|
||||
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
|
||||
const coeff = 0.044715
|
||||
|
||||
// x^3
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
// x + 0.044715 * x^3
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
// tanh(...)
|
||||
tanh := mlx.Tanh(scaled)
|
||||
// 1 + tanh(...)
|
||||
onePlusTanh := mlx.AddScalar(tanh, 1.0)
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
|
||||
}
|
||||
|
||||
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
|
||||
// Uses mlx.RMSNorm fast kernel with pre-computed (1 + weight)
|
||||
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
|
||||
// Gemma uses (1 + weight) instead of weight
|
||||
scaledWeight := mlx.AddScalar(weight, 1.0)
|
||||
return mlx.RMSNorm(x, scaledWeight, eps)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *TextModel) NumLayers() int { return len(m.Layers) }
|
||||
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
|
||||
|
||||
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
|
||||
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
// FormatPrompt applies the Gemma 3 chat template to a prompt
|
||||
func (m *TextModel) FormatPrompt(prompt string) string {
|
||||
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
// Use rotating cache for sliding window layers
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
// Use regular cache for global attention layers
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// Config holds config for the full multimodal model
|
||||
type Config struct {
|
||||
TextConfig TextConfig `json:"text_config"`
|
||||
VisionConfig VisionConfig `json:"vision_config"`
|
||||
|
||||
// Image token config (from config.json)
|
||||
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
|
||||
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
|
||||
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
|
||||
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
|
||||
}
|
||||
|
||||
// Model is the full Gemma 3 multimodal model
|
||||
type Model struct {
|
||||
VisionTower *VisionTower `weight:"vision_tower"`
|
||||
Projector *MultiModalProjector `weight:"multi_modal_projector"`
|
||||
TextModel *TextModel `weight:"language_model"`
|
||||
Config *Config
|
||||
tok *tokenizer.Tokenizer
|
||||
}
|
||||
|
||||
// Load loads the full multimodal Gemma 3 model
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Set defaults for text config (multimodal config often has incomplete text_config)
|
||||
// These defaults match transformers.Gemma3TextConfig defaults
|
||||
tc := &cfg.TextConfig
|
||||
if tc.HeadDim == 0 {
|
||||
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
|
||||
}
|
||||
if tc.NumAttentionHeads == 0 {
|
||||
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
|
||||
tc.NumAttentionHeads = 8
|
||||
}
|
||||
if tc.NumKeyValueHeads == 0 {
|
||||
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
|
||||
tc.NumKeyValueHeads = 4
|
||||
}
|
||||
if tc.VocabSize == 0 {
|
||||
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
|
||||
}
|
||||
if tc.RopeTheta == 0 {
|
||||
tc.RopeTheta = 1000000
|
||||
}
|
||||
if tc.RopeLocalBaseFreq == 0 {
|
||||
tc.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if tc.RMSNormEps == 0 {
|
||||
tc.RMSNormEps = 1e-6
|
||||
}
|
||||
if tc.SlidingWindowPattern == 0 {
|
||||
tc.SlidingWindowPattern = 6
|
||||
}
|
||||
if tc.MaxPositionEmbeddings == 0 {
|
||||
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
|
||||
}
|
||||
|
||||
// Compute text model scale
|
||||
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
|
||||
|
||||
// Set defaults for image token config
|
||||
if cfg.BOITokenIndex == 0 {
|
||||
cfg.BOITokenIndex = 255999 // <start_of_image>
|
||||
}
|
||||
if cfg.EOITokenIndex == 0 {
|
||||
cfg.EOITokenIndex = 256000 // <end_of_image>
|
||||
}
|
||||
if cfg.ImageTokenIndex == 0 {
|
||||
cfg.ImageTokenIndex = 262144 // <image_soft_token>
|
||||
}
|
||||
if cfg.MMTokensPerImage == 0 {
|
||||
cfg.MMTokensPerImage = 256
|
||||
}
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
VisionTower: &VisionTower{
|
||||
Embeddings: &VisionEmbeddings{},
|
||||
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
|
||||
Config: &cfg.VisionConfig,
|
||||
},
|
||||
Projector: &MultiModalProjector{},
|
||||
TextModel: &TextModel{
|
||||
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
|
||||
TextConfig: &cfg.TextConfig,
|
||||
},
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Initialize text layer metadata
|
||||
for i := range m.TextModel.Layers {
|
||||
m.TextModel.Layers[i] = &DecoderLayer{
|
||||
LayerIdx: int32(i),
|
||||
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize vision encoder layers
|
||||
for i := range m.VisionTower.Encoder {
|
||||
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
|
||||
}
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Tied embeddings for text output
|
||||
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
|
||||
m.TextModel.tok = tok
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
||||
precomputeGemmaScaledWeights(m.TextModel)
|
||||
|
||||
// Precompute projector's scaled weight
|
||||
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
|
||||
mlx.Eval(m.Projector.SoftEmbNormScaled)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the text-only forward pass
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
return m.TextModel.Forward(tokens, caches)
|
||||
}
|
||||
|
||||
// ForwardWithImage runs the multimodal forward pass
|
||||
// tokens: [B, L] input token IDs (with image placeholder tokens)
|
||||
// image: [B, H, W, C] preprocessed image tensor
|
||||
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
cfg := m.Config.TextConfig
|
||||
|
||||
// Find image token position FIRST before any eval that might free tokens
|
||||
imageStartPos := int32(-1)
|
||||
if image != nil && B == 1 {
|
||||
tokenData := tokens.DataInt32() // This evals tokens
|
||||
for i, t := range tokenData {
|
||||
if t == m.Config.ImageTokenIndex {
|
||||
imageStartPos = int32(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get text embeddings and scale
|
||||
h := m.TextModel.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
|
||||
|
||||
// Process image if provided
|
||||
if image != nil && imageStartPos >= 0 {
|
||||
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
|
||||
visionFeatures := m.VisionTower.Forward(image)
|
||||
|
||||
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
|
||||
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
|
||||
|
||||
// Eval h and imageEmbeds together so neither gets freed
|
||||
mlx.Eval(h, imageEmbeds)
|
||||
|
||||
// Cast imageEmbeds to match text embeddings dtype (bf16)
|
||||
if imageEmbeds.Dtype() != h.Dtype() {
|
||||
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
|
||||
mlx.Eval(imageEmbeds)
|
||||
}
|
||||
|
||||
// Insert image embeddings at the known position
|
||||
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
|
||||
}
|
||||
|
||||
// Run through text model layers
|
||||
for i, layer := range m.TextModel.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
|
||||
}
|
||||
|
||||
// Final norm and output projection
|
||||
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
|
||||
// at a known position (to avoid re-scanning tokens after eval)
|
||||
// textEmbeds: [B, L, hidden_size] text embeddings
|
||||
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
|
||||
// startPos: starting position of image tokens in the sequence
|
||||
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
|
||||
numImageTokens := imageEmbeds.Shape()[1]
|
||||
L := textEmbeds.Shape()[1]
|
||||
|
||||
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
|
||||
afterStart := startPos + numImageTokens
|
||||
|
||||
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
|
||||
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
|
||||
|
||||
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
|
||||
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
|
||||
|
||||
// Concatenate: before + imageEmbeds + after along axis 1
|
||||
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
|
||||
}
|
||||
|
||||
// Interface methods for Model
|
||||
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
|
||||
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
|
||||
|
||||
// FormatPrompt applies the Gemma 3 multimodal chat template
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
|
||||
func (m *Model) FormatPromptWithImage(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
|
||||
// Input tokens containing boi_token (255999) are expanded to:
|
||||
// boi_token + 256 * image_token + eoi_token
|
||||
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
|
||||
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
|
||||
|
||||
for _, t := range tokens {
|
||||
if t == m.Config.BOITokenIndex {
|
||||
// Expand: boi + 256 * image_token + eoi
|
||||
result = append(result, m.Config.BOITokenIndex)
|
||||
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
|
||||
result = append(result, m.Config.ImageTokenIndex)
|
||||
}
|
||||
result = append(result, m.Config.EOITokenIndex)
|
||||
} else {
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
58
x/imagegen/models/gemma3/image.go
Normal file
58
x/imagegen/models/gemma3/image.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// ProcessImage loads and preprocesses an image for the vision tower
|
||||
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
|
||||
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
|
||||
return ProcessImageData(img, imageSize)
|
||||
}
|
||||
|
||||
// ProcessImageData preprocesses an image.Image for the vision tower
|
||||
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
|
||||
// Resize to target size using bilinear interpolation
|
||||
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
|
||||
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Convert to float32 array [H, W, C] and normalize
|
||||
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
|
||||
data := make([]float32, imageSize*imageSize*3)
|
||||
idx := 0
|
||||
for y := int32(0); y < imageSize; y++ {
|
||||
for x := int32(0); x < imageSize; x++ {
|
||||
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
|
||||
// RGBA returns 16-bit values, convert to 8-bit
|
||||
data[idx] = float32(r>>8)/127.5 - 1.0
|
||||
data[idx+1] = float32(g>>8)/127.5 - 1.0
|
||||
data[idx+2] = float32(b>>8)/127.5 - 1.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
|
||||
// Create MLX array [1, H, W, C] for NHWC layout
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
|
||||
mlx.Eval(arr) // Materialize to prevent use-after-free
|
||||
return arr, nil
|
||||
}
|
||||
50
x/imagegen/models/gemma3/projector.go
Normal file
50
x/imagegen/models/gemma3/projector.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// MultiModalProjector projects vision features to text embedding space
|
||||
type MultiModalProjector struct {
|
||||
// mm_input_projection_weight: [vision_hidden, text_hidden]
|
||||
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
|
||||
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
SoftEmbNormScaled *mlx.Array `weight:"-"`
|
||||
}
|
||||
|
||||
// Forward projects vision features to text space
|
||||
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
|
||||
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
|
||||
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
|
||||
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
|
||||
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
|
||||
B := visionFeatures.Shape()[0]
|
||||
visionHidden := visionFeatures.Shape()[2]
|
||||
|
||||
// Reshape to [B, 64, 64, hidden]
|
||||
gridSize := int32(64) // sqrt(4096)
|
||||
pooledSize := int32(16) // 64/4
|
||||
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
|
||||
|
||||
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
|
||||
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
|
||||
|
||||
// Average over pooling dimensions (axes 2 and 4)
|
||||
h = mlx.Mean(h, 4, false)
|
||||
h = mlx.Mean(h, 2, false)
|
||||
|
||||
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
|
||||
numTokens := pooledSize * pooledSize
|
||||
h = mlx.Reshape(h, B, numTokens, visionHidden)
|
||||
|
||||
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
|
||||
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
|
||||
|
||||
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
|
||||
return mlx.Linear(h, p.InputProjection)
|
||||
}
|
||||
138
x/imagegen/models/gemma3/vision.go
Normal file
138
x/imagegen/models/gemma3/vision.go
Normal file
@@ -0,0 +1,138 @@
|
||||
//go:build mlx
|
||||
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
)
|
||||
|
||||
// VisionConfig holds configuration for the SigLIP vision tower
|
||||
type VisionConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
ImageSize int32 `json:"image_size"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
}
|
||||
|
||||
// VisionTower is the SigLIP vision encoder
|
||||
type VisionTower struct {
|
||||
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
|
||||
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
|
||||
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
|
||||
Config *VisionConfig
|
||||
}
|
||||
|
||||
// VisionEmbeddings handles patch and position embeddings
|
||||
type VisionEmbeddings struct {
|
||||
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
|
||||
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
|
||||
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
|
||||
PosEmbed *nn.Embedding `weight:"position_embedding"`
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is a single transformer encoder layer
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
|
||||
Attention *VisionAttention `weight:"self_attn"`
|
||||
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
|
||||
MLP *VisionMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// VisionAttention implements multi-head self-attention
|
||||
type VisionAttention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OutProj *nn.Linear `weight:"out_proj"`
|
||||
}
|
||||
|
||||
// VisionMLP is the feed-forward network
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `weight:"fc1"`
|
||||
FC2 *nn.Linear `weight:"fc2"`
|
||||
}
|
||||
|
||||
// Forward runs the vision tower on preprocessed images
|
||||
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
|
||||
// Output: [B, num_patches, hidden_size]
|
||||
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
|
||||
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
|
||||
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
|
||||
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
|
||||
|
||||
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
|
||||
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
|
||||
h = mlx.Add(h, bias)
|
||||
|
||||
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
|
||||
B := h.Shape()[0]
|
||||
gridH, gridW := h.Shape()[1], h.Shape()[2]
|
||||
hidden := h.Shape()[3]
|
||||
numPatches := gridH * gridW
|
||||
h = mlx.Reshape(h, B, numPatches, hidden)
|
||||
|
||||
// Add position embeddings
|
||||
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
|
||||
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
|
||||
h = mlx.Add(h, posEmbed)
|
||||
|
||||
// Encoder layers
|
||||
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
|
||||
scale := float32(1.0 / math.Sqrt(float64(headDim)))
|
||||
for _, layer := range v.Encoder {
|
||||
h = layer.Forward(h, v.Config, scale)
|
||||
}
|
||||
|
||||
// Final layer norm
|
||||
h = v.PostLayerNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward runs a vision encoder layer
|
||||
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
// Pre-norm attention
|
||||
h := l.LayerNorm1.Forward(x)
|
||||
h = l.Attention.Forward(h, cfg, scale)
|
||||
x = mlx.Add(x, h)
|
||||
|
||||
// Pre-norm MLP
|
||||
h = l.LayerNorm2.Forward(x)
|
||||
h = l.MLP.Forward(h)
|
||||
return mlx.Add(x, h)
|
||||
}
|
||||
|
||||
// Forward runs multi-head self-attention
|
||||
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
|
||||
B, L := x.Shape()[0], x.Shape()[1]
|
||||
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention (no causal mask for vision)
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
|
||||
|
||||
return a.OutProj.Forward(out)
|
||||
}
|
||||
|
||||
// Forward runs the MLP with GELU activation
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := mlx.GELU(m.FC1.Forward(x))
|
||||
return m.FC2.Forward(h)
|
||||
}
|
||||
487
x/imagegen/models/gpt_oss/gpt_oss.go
Normal file
487
x/imagegen/models/gpt_oss/gpt_oss.go
Normal file
@@ -0,0 +1,487 @@
|
||||
//go:build mlx
|
||||
|
||||
package gpt_oss
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// RopeScaling holds YaRN or other RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
NumLocalExperts int32 `json:"num_local_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
SwiGLULimit float32 `json:"swiglu_limit"`
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
|
||||
YarnFreqs *mlx.Array // computed
|
||||
YarnMscale float32
|
||||
}
|
||||
|
||||
// swiGLU applies the GPT-OSS custom SwiGLU activation.
|
||||
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
|
||||
// with clipping: gate to [None, limit], up to [-limit, limit]
|
||||
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
|
||||
// Clip gate to [None, limit]
|
||||
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
|
||||
|
||||
// Clip up to [-limit, limit]
|
||||
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
|
||||
|
||||
// glu_scaled = alpha * gate_clipped
|
||||
gluScaled := mlx.MulScalar(gateClipped, alpha)
|
||||
|
||||
// sig = sigmoid(glu_scaled)
|
||||
sig := mlx.Sigmoid(gluScaled)
|
||||
|
||||
// out_glu = gate_clipped * sig
|
||||
outGlu := mlx.Mul(gateClipped, sig)
|
||||
|
||||
// result = out_glu * (up_clipped + 1)
|
||||
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
|
||||
}
|
||||
|
||||
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
const alpha float32 = 1.702
|
||||
const limit float32 = 7.0
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
|
||||
}, true) // shapeless=true so it works for any input size
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
|
||||
// Based on mlx-lm's YarnRoPE implementation
|
||||
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
|
||||
// yarn_find_correction_dim
|
||||
yarnFindCorrectionDim := func(numRotations float64) float64 {
|
||||
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
|
||||
}
|
||||
|
||||
// yarn_find_correction_range
|
||||
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
|
||||
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
|
||||
if low < 0 {
|
||||
low = 0
|
||||
}
|
||||
if high > int(dims)-1 {
|
||||
high = int(dims) - 1
|
||||
}
|
||||
|
||||
// yarn_get_mscale
|
||||
yarnGetMscale := func(scale, mscale float64) float64 {
|
||||
if scale <= 1 {
|
||||
return 1.0
|
||||
}
|
||||
return 0.1*mscale*math.Log(scale) + 1.0
|
||||
}
|
||||
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
|
||||
|
||||
// Compute frequencies
|
||||
// freq_extra = base ** (arange(0, dims, 2) / dims)
|
||||
// freq_inter = scaling_factor * freq_extra
|
||||
halfDims := dims / 2
|
||||
freqData := make([]float32, halfDims)
|
||||
for i := int32(0); i < halfDims; i++ {
|
||||
exp := float64(2*i) / float64(dims)
|
||||
freqExtra := math.Pow(float64(base), exp)
|
||||
freqInter := float64(scalingFactor) * freqExtra
|
||||
|
||||
// linear ramp mask
|
||||
var freqMask float64
|
||||
if low == high {
|
||||
freqMask = 0.0
|
||||
} else {
|
||||
t := (float64(i) - float64(low)) / float64(high-low)
|
||||
if t < 0 {
|
||||
t = 0
|
||||
}
|
||||
if t > 1 {
|
||||
t = 1
|
||||
}
|
||||
freqMask = 1.0 - t
|
||||
}
|
||||
|
||||
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
|
||||
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
|
||||
}
|
||||
|
||||
return mlx.NewArray(freqData, []int32{halfDims}), mscale
|
||||
}
|
||||
|
||||
// initYarn initializes YaRN RoPE if configured
|
||||
func (a *Attention) initYarn(cfg *Config) {
|
||||
a.YarnMscale = 1.0
|
||||
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
|
||||
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
|
||||
cfg.HeadDim,
|
||||
cfg.RopeTheta,
|
||||
cfg.RopeScaling.Factor,
|
||||
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
|
||||
cfg.RopeScaling.BetaFast,
|
||||
cfg.RopeScaling.BetaSlow,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
if a.YarnFreqs != nil {
|
||||
if a.YarnMscale != 1.0 {
|
||||
q = mlx.MulScalar(q, a.YarnMscale)
|
||||
}
|
||||
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
|
||||
} else {
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// CreateSlidingWindowMask creates a causal mask with sliding window
|
||||
// Mirrors mlx-lm's create_causal_mask with window_size
|
||||
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
|
||||
// Build mask aligned to actual cache length (may be rotated)
|
||||
// rinds covers existing keys: [keyStart, keyStart+keyLen)
|
||||
// linds covers new queries: [queryStart, queryStart+seqLen)
|
||||
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
|
||||
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
|
||||
|
||||
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
|
||||
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
|
||||
|
||||
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
|
||||
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
|
||||
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
|
||||
|
||||
return mlx.LogicalAnd(causalMask, windowMask)
|
||||
}
|
||||
|
||||
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
|
||||
type MoE struct {
|
||||
Router *nn.Linear `weight:"mlp.router"`
|
||||
TopK int32
|
||||
HiddenSize int32
|
||||
GroupSize int
|
||||
Bits int
|
||||
// Expert weights (loaded manually via sanitizeExpertWeights)
|
||||
GateBlocks, GateScales, GateBias *mlx.Array
|
||||
UpBlocks, UpScales, UpBias *mlx.Array
|
||||
DownBlocks, DownScales, DownBias *mlx.Array
|
||||
}
|
||||
|
||||
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
|
||||
logits := moe.Router.Forward(x)
|
||||
neg := mlx.Neg(logits)
|
||||
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
|
||||
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
|
||||
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
|
||||
weights := mlx.Softmax(topKVal, -1)
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
|
||||
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
|
||||
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
sorted := false
|
||||
n := B * L * moe.TopK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
sorted = true
|
||||
}
|
||||
|
||||
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
|
||||
if moe.GateBias != nil {
|
||||
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
|
||||
}
|
||||
if moe.UpBias != nil {
|
||||
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
hidden := getCompiledSwiGLU().Call(gate, up)[0]
|
||||
|
||||
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
|
||||
if moe.DownBias != nil {
|
||||
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
|
||||
}
|
||||
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
|
||||
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
|
||||
}
|
||||
|
||||
type Block struct {
|
||||
Attention *Attention
|
||||
MLP *MoE
|
||||
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
LayerType string // "sliding_attention" or "full_attention"
|
||||
}
|
||||
|
||||
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
|
||||
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead *nn.Linear `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
func (m *Model) NewCache(int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
x := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
// Find representative cache indices for sliding window attention
|
||||
var swaIdx int = -1
|
||||
for i, layer := range m.Layers {
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
swaIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Create masks once at model level
|
||||
var fullMask, swaMask *mlx.Array
|
||||
var fullMaskMode, swaMaskMode string
|
||||
|
||||
if L > 1 {
|
||||
fullMaskMode = "causal"
|
||||
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
|
||||
c := caches[swaIdx]
|
||||
offset := c.Offset()
|
||||
windowSize := int(m.SlidingWindow)
|
||||
cacheLen := min(int(L), windowSize)
|
||||
if offset > 0 {
|
||||
cacheLen = min(c.Len()+int(L), windowSize)
|
||||
}
|
||||
if int(L) > windowSize {
|
||||
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
} else {
|
||||
swaMaskMode = "causal"
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
mask, maskMode := fullMask, fullMaskMode
|
||||
if layer.LayerType == "sliding_attention" {
|
||||
mask, maskMode = swaMask, swaMaskMode
|
||||
}
|
||||
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
|
||||
}
|
||||
|
||||
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
|
||||
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
|
||||
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
|
||||
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
|
||||
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
|
||||
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
|
||||
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
|
||||
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
|
||||
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
|
||||
|
||||
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
|
||||
|
||||
if gateUpBlocks != nil {
|
||||
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
s := gub.Shape()
|
||||
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpScales != nil {
|
||||
s := gateUpScales.Shape()
|
||||
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
|
||||
}
|
||||
if gateUpBias != nil {
|
||||
s := gateUpBias.Shape()
|
||||
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
|
||||
}
|
||||
if downBlocks != nil {
|
||||
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
|
||||
}
|
||||
return moe
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load simple weights via struct tags
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers with custom MoE handling
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
layer := &Block{}
|
||||
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Initialize attention YaRN
|
||||
layer.Attention.initYarn(&cfg)
|
||||
|
||||
// Load MoE with weight sanitization
|
||||
moe := sanitizeExpertWeights(weights, prefix)
|
||||
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
|
||||
moe.TopK = cfg.NumExpertsPerTok
|
||||
moe.HiddenSize = cfg.HiddenSize
|
||||
layer.MLP = moe
|
||||
|
||||
// Set layer type
|
||||
layer.LayerType = "full_attention"
|
||||
if int(i) < len(cfg.LayerTypes) {
|
||||
layer.LayerType = cfg.LayerTypes[i]
|
||||
}
|
||||
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured data,
|
||||
// this reduces peak memory by freeing mmap during materialization
|
||||
weights.ReleaseAll()
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int32 {
|
||||
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
|
||||
return m.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
}
|
||||
return 131072
|
||||
}
|
||||
152
x/imagegen/models/llama/llama.go
Normal file
152
x/imagegen/models/llama/llama.go
Normal file
@@ -0,0 +1,152 @@
|
||||
//go:build mlx
|
||||
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
HeadDim int32 `json:"-"`
|
||||
Scale float32 `json:"-"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Layer `weight:"model.layers"`
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
Output *nn.Linear `weight:"lm_head,optional"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
Attention *Attention
|
||||
MLP *MLP
|
||||
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
VProj *nn.Linear `weight:"self_attn.v_proj"`
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.Forward(h, caches[i], B, L, m.Config)
|
||||
}
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
|
||||
k, v = c.Update(k, v, int(L))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
66
x/imagegen/models/qwen_image/pipeline_test.go
Normal file
66
x/imagegen/models/qwen_image/pipeline_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
modelPath := "../../../weights/Qwen-Image-2512"
|
||||
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + modelPath)
|
||||
}
|
||||
|
||||
// Load model
|
||||
pm, err := LoadPersistent(modelPath)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: failed to load model: %v", err)
|
||||
}
|
||||
|
||||
// Run 2-step pipeline (minimum for stable scheduler)
|
||||
cfg := &GenerateConfig{
|
||||
Prompt: "a cat",
|
||||
Width: 256,
|
||||
Height: 256,
|
||||
Steps: 2,
|
||||
Seed: 42,
|
||||
}
|
||||
|
||||
output, err := pm.GenerateFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Pipeline failed: %v", err)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
// Verify output shape [1, C, H, W]
|
||||
shape := output.Shape()
|
||||
if len(shape) != 4 {
|
||||
t.Errorf("Expected 4D output, got %v", shape)
|
||||
}
|
||||
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
|
||||
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
|
||||
}
|
||||
|
||||
// Verify values in expected range [0, 1]
|
||||
data := output.Data()
|
||||
minVal, maxVal := float32(1.0), float32(0.0)
|
||||
for _, v := range data {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
|
||||
|
||||
if minVal < -0.1 || maxVal > 1.1 {
|
||||
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
|
||||
}
|
||||
}
|
||||
1802
x/imagegen/models/qwen_image/qwen25vl.go
Normal file
1802
x/imagegen/models/qwen_image/qwen25vl.go
Normal file
File diff suppressed because it is too large
Load Diff
350
x/imagegen/models/qwen_image/qwen_image.go
Normal file
350
x/imagegen/models/qwen_image/qwen_image.go
Normal file
@@ -0,0 +1,350 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image implements the Qwen-Image diffusion transformer model.
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen25VL
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
|
||||
m.TextEncoder = &Qwen25VL{}
|
||||
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 30
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
imgSeqLen := pH * pW
|
||||
|
||||
// Text encoding
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
if useCFG {
|
||||
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Init latents [B, C, T, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs)
|
||||
}
|
||||
|
||||
// Layer cache for DeepCache/Learning-to-Cache speedup
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
|
||||
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := PackLatents(latents2D, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
// True CFG: run twice and combine with norm rescaling
|
||||
// Note: layer caching with CFG is not supported yet (would need 2 caches)
|
||||
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
combPred := mlx.Add(negOutput, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
} else if stepCache != nil {
|
||||
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
|
||||
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
}
|
||||
|
||||
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep cached arrays alive across cleanup
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode (Decode manages its own pools for staged memory)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
{
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
mlx.Eval(decoded)
|
||||
}
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
// Use m := &Model{}; m.Load(path) instead.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
218
x/imagegen/models/qwen_image/scheduler.go
Normal file
218
x/imagegen/models/qwen_image/scheduler.go
Normal file
@@ -0,0 +1,218 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.5
|
||||
MaxShift float32 `json:"max_shift"` // 0.9
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
|
||||
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
|
||||
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.5,
|
||||
MaxShift: 0.9, // Matches scheduler_config.json
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 8192,
|
||||
ShiftTerminal: 0.02,
|
||||
UseDynamicShift: true,
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32
|
||||
Sigmas []float32
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift based on image sequence length
|
||||
// This matches Python's calculate_shift function
|
||||
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
|
||||
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
|
||||
b := baseShift - m*float32(baseSeqLen)
|
||||
mu := float32(imageSeqLen)*m + b
|
||||
return mu
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
|
||||
// 1. Create sigmas from sigma_max to sigma_min (linspace)
|
||||
// 2. Apply time_shift with mu (if dynamic shifting)
|
||||
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate mu for dynamic shifting
|
||||
var mu float32
|
||||
if s.Config.UseDynamicShift {
|
||||
mu = CalculateShift(
|
||||
imageSeqLen,
|
||||
s.Config.BaseImageSeqLen,
|
||||
s.Config.MaxImageSeqLen,
|
||||
s.Config.BaseShift,
|
||||
s.Config.MaxShift,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 1: Create sigmas from 1.0 to 1/num_steps
|
||||
// Python (pipeline_qwenimage.py:639):
|
||||
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
|
||||
sigmas := make([]float32, numSteps)
|
||||
sigmaMax := float32(1.0)
|
||||
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
|
||||
if numSteps == 1 {
|
||||
sigmas[0] = sigmaMax
|
||||
} else {
|
||||
for i := 0; i < numSteps; i++ {
|
||||
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShift && mu != 0 {
|
||||
for i := range sigmas {
|
||||
sigmas[i] = s.timeShift(mu, sigmas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Apply stretch_shift_to_terminal
|
||||
if s.Config.ShiftTerminal > 0 {
|
||||
sigmas = s.stretchShiftToTerminal(sigmas)
|
||||
}
|
||||
|
||||
// Step 4: Append terminal sigma (0) and store
|
||||
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
|
||||
// before passing to transformer. We skip both steps and just use sigmas directly.
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
s.Sigmas[i] = sigmas[i]
|
||||
s.Timesteps[i] = sigmas[i]
|
||||
}
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
s.Timesteps[numSteps] = 0.0
|
||||
}
|
||||
|
||||
// stretchShiftToTerminal stretches and shifts the timestep schedule
|
||||
// so the final value equals shift_terminal (matches Python behavior)
|
||||
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
|
||||
if len(sigmas) == 0 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
// one_minus_z = 1 - t
|
||||
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||
// stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
lastSigma := sigmas[len(sigmas)-1]
|
||||
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
|
||||
|
||||
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
|
||||
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
|
||||
if scaleFactor < 1e-6 {
|
||||
return sigmas
|
||||
}
|
||||
|
||||
result := make([]float32, len(sigmas))
|
||||
for i, t := range sigmas {
|
||||
oneMinusZ := 1.0 - t
|
||||
result[i] = 1.0 - (oneMinusZ / scaleFactor)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (exponential)
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity from the transformer
|
||||
// sample: current noisy sample
|
||||
// timestepIdx: current timestep index
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 to avoid precision issues (matches Python diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to original dtype
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
|
||||
// This matches how Python diffusers generates noise - directly in packed space.
|
||||
// Generating in unpacked format and then packing produces different spatial
|
||||
// correlation structure, which affects model output quality.
|
||||
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
|
||||
shape := []int32{batchSize, seqLen, channels}
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
|
||||
func GetLatentShape(batchSize, height, width int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
|
||||
}
|
||||
|
||||
// GetPatchedLatentShape returns the patchified latent shape
|
||||
// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
|
||||
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
pH := latentH / patchSize
|
||||
pW := latentW / patchSize
|
||||
inChannels := int32(64) // 16 * patch_size^2
|
||||
return []int32{batchSize, pH * pW, inChannels}
|
||||
}
|
||||
135
x/imagegen/models/qwen_image/scheduler_test.go
Normal file
135
x/imagegen/models/qwen_image/scheduler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
|
||||
// Golden values generated via:
|
||||
//
|
||||
// python3 -c "
|
||||
// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
// import numpy as np
|
||||
// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
|
||||
// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
|
||||
// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
|
||||
// sigmas = np.linspace(1.0, 1.0/30, 30)
|
||||
// s.set_timesteps(sigmas=sigmas, mu=mu)
|
||||
// print(s.sigmas.numpy())"
|
||||
func TestSchedulerSetTimesteps(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Golden values from Python diffusers (first 3, last 3 before terminal)
|
||||
wantFirst := []float32{1.000000, 0.982251, 0.963889}
|
||||
wantLast := []float32{0.142924, 0.083384, 0.020000}
|
||||
|
||||
// Check first 3
|
||||
for i, want := range wantFirst {
|
||||
got := scheduler.Sigmas[i]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check last 3 (indices 27, 28, 29)
|
||||
for i, want := range wantLast {
|
||||
idx := 27 + i
|
||||
got := scheduler.Sigmas[idx]
|
||||
if abs32(got-want) > 1e-4 {
|
||||
t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Check terminal is 0
|
||||
if scheduler.Sigmas[30] != 0.0 {
|
||||
t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(scheduler.Sigmas) != 31 {
|
||||
t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerProperties tests mathematical invariants of the scheduler.
|
||||
func TestSchedulerProperties(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Property: sigmas monotonically decreasing
|
||||
for i := 1; i < len(scheduler.Sigmas); i++ {
|
||||
if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
|
||||
t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
|
||||
i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Property: first sigma should be ~1.0 (with time shift)
|
||||
if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
|
||||
t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
|
||||
}
|
||||
|
||||
// Property: terminal sigma should be exactly 0
|
||||
if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
|
||||
t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
|
||||
}
|
||||
|
||||
// Property: last non-terminal sigma should be shift_terminal (0.02)
|
||||
lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
|
||||
if abs32(lastNonTerminal-0.02) > 1e-5 {
|
||||
t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
|
||||
}
|
||||
|
||||
// Property: length = steps + 1
|
||||
if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
|
||||
t.Errorf("sigmas length should be steps+1: got %d, want %d",
|
||||
len(scheduler.Sigmas), scheduler.NumSteps+1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateShift verifies the mu calculation against Python reference.
|
||||
// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
func TestCalculateShift(t *testing.T) {
|
||||
cases := []struct {
|
||||
imgSeqLen int32
|
||||
want float32
|
||||
}{
|
||||
{256, 0.5}, // base case
|
||||
{8192, 0.9}, // max case
|
||||
{4096, 0.6935}, // middle case (rounded)
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
|
||||
if abs32(got-c.want) > 0.001 {
|
||||
t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSchedulerStep verifies the Euler step formula.
|
||||
func TestSchedulerStep(t *testing.T) {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
scheduler := NewFlowMatchScheduler(cfg)
|
||||
scheduler.SetTimesteps(30, 4096)
|
||||
|
||||
// Verify dt calculation for first step
|
||||
sigma0 := scheduler.Sigmas[0]
|
||||
sigma1 := scheduler.Sigmas[1]
|
||||
expectedDt := sigma1 - sigma0
|
||||
|
||||
// dt should be negative (sigmas decrease)
|
||||
if expectedDt >= 0 {
|
||||
t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
|
||||
}
|
||||
}
|
||||
|
||||
func abs32(x float32) float32 {
|
||||
return float32(math.Abs(float64(x)))
|
||||
}
|
||||
174
x/imagegen/models/qwen_image/text_encoder_test.go
Normal file
174
x/imagegen/models/qwen_image/text_encoder_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TinyTextEncoderConfig holds config for the tiny test text encoder
|
||||
type TinyTextEncoderConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
}
|
||||
|
||||
// loadTinyTextEncoder loads the tiny text encoder from testdata
|
||||
func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
|
||||
t.Helper()
|
||||
|
||||
testdataDir := filepath.Join("testdata", "tiny_text_encoder")
|
||||
|
||||
// Load config
|
||||
configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
|
||||
if err != nil {
|
||||
t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
|
||||
}
|
||||
|
||||
var tinyCfg TinyTextEncoderConfig
|
||||
if err := json.Unmarshal(configData, &tinyCfg); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create encoder config (using Qwen25VLConfig)
|
||||
cfg := &Qwen25VLConfig{
|
||||
HiddenSize: tinyCfg.HiddenSize,
|
||||
NumHiddenLayers: tinyCfg.NumHiddenLayers,
|
||||
IntermediateSize: tinyCfg.IntermediateSize,
|
||||
NumAttentionHeads: tinyCfg.NumAttentionHeads,
|
||||
NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
|
||||
VocabSize: tinyCfg.VocabSize,
|
||||
RMSNormEps: tinyCfg.RMSNormEps,
|
||||
RopeTheta: tinyCfg.RopeTheta,
|
||||
HeadDim: tinyCfg.HeadDim,
|
||||
MRoPESection: tinyCfg.MRoPESection,
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(testdataDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load weights: %v", err)
|
||||
}
|
||||
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Failed to bulk load weights: %v", err)
|
||||
}
|
||||
|
||||
// Build encoder
|
||||
embedding, err := weights.Get("model.embed_tokens.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get embedding: %v", err)
|
||||
}
|
||||
|
||||
blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
block, err := newVLTextBlock(weights, int(i), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load block %d: %v", i, err)
|
||||
}
|
||||
blocks[i] = block
|
||||
}
|
||||
|
||||
finalNorm, err := weights.Get("model.norm.weight")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get final norm: %v", err)
|
||||
}
|
||||
|
||||
encoder := &Qwen25VL{
|
||||
Config: cfg,
|
||||
Embedding: embedding,
|
||||
Blocks: blocks,
|
||||
FinalNorm: finalNorm,
|
||||
HasVision: false, // Text-only mode
|
||||
}
|
||||
|
||||
return encoder, &tinyCfg
|
||||
}
|
||||
|
||||
// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
|
||||
func TestTextEncoderForward(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// Create test tokens (within vocab range)
|
||||
tokens := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Forward pass using EncodeTextOnly
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, seq_len, hidden_size]
|
||||
wantShape := []int32{1, 5, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite (not NaN or Inf)
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
t.Errorf("output[%d] is not finite: %v", i, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextEncoderBatch tests batch processing.
|
||||
func TestTextEncoderBatch(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
// For batch test, we'll use EncodeTextOnly with a single sequence
|
||||
// (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
|
||||
tokens := []int32{1, 2, 3}
|
||||
|
||||
out := encoder.EncodeTextOnly(tokens)
|
||||
mlx.Eval(out)
|
||||
|
||||
wantShape := []int32{1, 3, cfg.HiddenSize}
|
||||
if !slices.Equal(out.Shape(), wantShape) {
|
||||
t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
|
||||
func TestMRoPEComputation(t *testing.T) {
|
||||
encoder, cfg := loadTinyTextEncoder(t)
|
||||
|
||||
cossin := encoder.computeTextRoPE(10, 1)
|
||||
mlx.Eval(cossin[0], cossin[1])
|
||||
|
||||
// Verify shapes: [3, B, L, head_dim]
|
||||
wantShape := []int32{3, 1, 10, cfg.HeadDim}
|
||||
if !slices.Equal(cossin[0].Shape(), wantShape) {
|
||||
t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
|
||||
}
|
||||
if !slices.Equal(cossin[1].Shape(), wantShape) {
|
||||
t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
|
||||
}
|
||||
|
||||
// Verify cos/sin values are in valid range [-1, 1]
|
||||
cosData := cossin[0].Data()
|
||||
sinData := cossin[1].Data()
|
||||
for i := 0; i < min(100, len(cosData)); i++ {
|
||||
if cosData[i] < -1.01 || cosData[i] > 1.01 {
|
||||
t.Errorf("cos[%d] out of range: %v", i, cosData[i])
|
||||
}
|
||||
if sinData[i] < -1.01 || sinData[i] > 1.01 {
|
||||
t.Errorf("sin[%d] out of range: %v", i, sinData[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
868
x/imagegen/models/qwen_image/transformer.go
Normal file
868
x/imagegen/models/qwen_image/transformer.go
Normal file
@@ -0,0 +1,868 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Qwen-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
|
||||
NHeads int32 `json:"num_attention_heads"` // 24
|
||||
HeadDim int32 `json:"attention_head_dim"` // 128
|
||||
NLayers int32 `json:"num_layers"` // 60
|
||||
InChannels int32 `json:"in_channels"` // 64
|
||||
OutChannels int32 `json:"out_channels"` // 16
|
||||
PatchSize int32 `json:"patch_size"` // 2
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
|
||||
NormEps float32 `json:"norm_eps"` // 1e-6
|
||||
AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false
|
||||
}
|
||||
|
||||
// defaultTransformerConfig returns config for Qwen-Image transformer
|
||||
func defaultTransformerConfig() *TransformerConfig {
|
||||
return &TransformerConfig{
|
||||
HiddenDim: 3072, // 24 * 128
|
||||
NHeads: 24,
|
||||
HeadDim: 128,
|
||||
NLayers: 60,
|
||||
InChannels: 64,
|
||||
OutChannels: 16,
|
||||
PatchSize: 2,
|
||||
JointAttentionDim: 3584,
|
||||
NormEps: 1e-6,
|
||||
AxesDimsRope: []int32{16, 56, 56},
|
||||
GuidanceEmbeds: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
type TimestepEmbedder struct {
|
||||
Linear1Weight *mlx.Array // [256, hidden_dim]
|
||||
Linear1Bias *mlx.Array
|
||||
Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
|
||||
Linear2Bias *mlx.Array
|
||||
}
|
||||
|
||||
// newTimestepEmbedder creates a timestep embedder from weights
|
||||
func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
|
||||
linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TimestepEmbedder{
|
||||
Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
|
||||
Linear1Bias: linear1Bias,
|
||||
Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
|
||||
Linear2Bias: linear2Bias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
half := int32(128) // embedding_dim / 2
|
||||
|
||||
// Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
tExpanded := mlx.ExpandDims(t, 1)
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
args = mlx.MulScalar(args, 1000.0) // scale
|
||||
|
||||
// [cos, sin] (flip_sin_to_cos=True)
|
||||
sinArgs := mlx.Sin(args)
|
||||
cosArgs := mlx.Cos(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := mlx.Linear(embedding, te.Linear1Weight)
|
||||
h = mlx.Add(h, te.Linear1Bias)
|
||||
h = mlx.SiLU(h)
|
||||
h = mlx.Linear(h, te.Linear2Weight)
|
||||
h = mlx.Add(h, te.Linear2Bias)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// JointAttention implements dual-stream joint attention
|
||||
type JointAttention struct {
|
||||
// Image projections
|
||||
ToQ *mlx.Array
|
||||
ToQB *mlx.Array
|
||||
ToK *mlx.Array
|
||||
ToKB *mlx.Array
|
||||
ToV *mlx.Array
|
||||
ToVB *mlx.Array
|
||||
ToOut *mlx.Array
|
||||
ToOutB *mlx.Array
|
||||
NormQ *mlx.Array
|
||||
NormK *mlx.Array
|
||||
|
||||
// Text (added) projections
|
||||
AddQProj *mlx.Array
|
||||
AddQProjB *mlx.Array
|
||||
AddKProj *mlx.Array
|
||||
AddKProjB *mlx.Array
|
||||
AddVProj *mlx.Array
|
||||
AddVProjB *mlx.Array
|
||||
ToAddOut *mlx.Array
|
||||
ToAddOutB *mlx.Array
|
||||
NormAddQ *mlx.Array
|
||||
NormAddK *mlx.Array
|
||||
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// newJointAttention creates a joint attention layer
|
||||
func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
|
||||
toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
|
||||
toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
|
||||
toK, _ := weights.Get(prefix + ".attn.to_k.weight")
|
||||
toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
|
||||
toV, _ := weights.Get(prefix + ".attn.to_v.weight")
|
||||
toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
|
||||
toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
|
||||
toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
|
||||
normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
|
||||
normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
|
||||
|
||||
addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
|
||||
addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
|
||||
addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
|
||||
addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
|
||||
addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
|
||||
addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
|
||||
toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
|
||||
toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
|
||||
normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
|
||||
normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
|
||||
|
||||
return &JointAttention{
|
||||
ToQ: mlx.Transpose(toQ, 1, 0),
|
||||
ToQB: toQB,
|
||||
ToK: mlx.Transpose(toK, 1, 0),
|
||||
ToKB: toKB,
|
||||
ToV: mlx.Transpose(toV, 1, 0),
|
||||
ToVB: toVB,
|
||||
ToOut: mlx.Transpose(toOut, 1, 0),
|
||||
ToOutB: toOutB,
|
||||
NormQ: normQ,
|
||||
NormK: normK,
|
||||
AddQProj: mlx.Transpose(addQProj, 1, 0),
|
||||
AddQProjB: addQProjB,
|
||||
AddKProj: mlx.Transpose(addKProj, 1, 0),
|
||||
AddKProjB: addKProjB,
|
||||
AddVProj: mlx.Transpose(addVProj, 1, 0),
|
||||
AddVProjB: addVProjB,
|
||||
ToAddOut: mlx.Transpose(toAddOut, 1, 0),
|
||||
ToAddOutB: toAddOutB,
|
||||
NormAddQ: normAddQ,
|
||||
NormAddK: normAddK,
|
||||
NHeads: cfg.NHeads,
|
||||
HeadDim: cfg.HeadDim,
|
||||
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward computes joint attention
|
||||
// img: [B, L_img, D], txt: [B, L_txt, D]
|
||||
// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
|
||||
func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
D := imgShape[2]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// === Image Q/K/V ===
|
||||
imgFlat := mlx.Reshape(img, B*Limg, D)
|
||||
qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
|
||||
kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
|
||||
vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
|
||||
|
||||
qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm (RMSNorm per head)
|
||||
qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
|
||||
kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
|
||||
|
||||
// Apply RoPE
|
||||
if imgFreqs != nil {
|
||||
qImg = applyRoPE(qImg, imgFreqs)
|
||||
kImg = applyRoPE(kImg, imgFreqs)
|
||||
}
|
||||
|
||||
// === Text Q/K/V ===
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, D)
|
||||
qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
|
||||
kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
|
||||
vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
|
||||
|
||||
qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
|
||||
|
||||
qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
|
||||
kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
|
||||
|
||||
if txtFreqs != nil {
|
||||
qTxt = applyRoPE(qTxt, txtFreqs)
|
||||
kTxt = applyRoPE(kTxt, txtFreqs)
|
||||
}
|
||||
|
||||
// Concatenate for joint attention: [txt, img] order
|
||||
qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
|
||||
kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
|
||||
vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
|
||||
kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
|
||||
vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
|
||||
|
||||
// Transpose back and split
|
||||
outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
|
||||
outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
|
||||
|
||||
outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
|
||||
outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
|
||||
|
||||
// Output projections
|
||||
outImg = mlx.Reshape(outImg, B*Limg, D)
|
||||
outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
|
||||
outImg = mlx.Reshape(outImg, B, Limg, D)
|
||||
|
||||
outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
|
||||
outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
|
||||
outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
|
||||
|
||||
return outImg, outTxt
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary embeddings using complex multiplication
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
|
||||
func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
halfDim := headDim / 2
|
||||
|
||||
// Reshape x to pairs: [B, L, nheads, half, 2]
|
||||
xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
|
||||
|
||||
// freqs: [L, head_dim] -> [1, L, 1, half, 2]
|
||||
freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
|
||||
|
||||
// Extract real/imag parts
|
||||
xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
xReal = mlx.Squeeze(xReal, 4)
|
||||
xImag = mlx.Squeeze(xImag, 4)
|
||||
|
||||
freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
|
||||
freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
|
||||
freqReal = mlx.Squeeze(freqReal, 4)
|
||||
freqImag = mlx.Squeeze(freqImag, 4)
|
||||
|
||||
// Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
|
||||
outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
|
||||
|
||||
// Interleave back
|
||||
outReal = mlx.ExpandDims(outReal, 4)
|
||||
outImag = mlx.ExpandDims(outImag, 4)
|
||||
out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
|
||||
|
||||
return mlx.Reshape(out, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// MLP implements GELU MLP (not GEGLU)
|
||||
type MLP struct {
|
||||
ProjWeight *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
OutWeight *mlx.Array
|
||||
OutBias *mlx.Array
|
||||
}
|
||||
|
||||
// newMLP creates a GELU MLP
|
||||
func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
|
||||
projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
|
||||
outWeight, _ := weights.Get(prefix + ".net.2.weight")
|
||||
outBias, _ := weights.Get(prefix + ".net.2.bias")
|
||||
|
||||
return &MLP{
|
||||
ProjWeight: mlx.Transpose(projWeight, 1, 0),
|
||||
ProjBias: projBias,
|
||||
OutWeight: mlx.Transpose(outWeight, 1, 0),
|
||||
OutBias: outBias,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies GELU MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
|
||||
h = geluApprox(h)
|
||||
h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
|
||||
return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
|
||||
}
|
||||
|
||||
// geluApprox implements approximate GELU
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
|
||||
inner = mlx.MulScalar(inner, sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// TransformerBlock is a single dual-stream transformer block
|
||||
type TransformerBlock struct {
|
||||
Attention *JointAttention
|
||||
ImgMLP *MLP
|
||||
TxtMLP *MLP
|
||||
|
||||
ImgModWeight *mlx.Array
|
||||
ImgModBias *mlx.Array
|
||||
TxtModWeight *mlx.Array
|
||||
TxtModBias *mlx.Array
|
||||
|
||||
HiddenDim int32
|
||||
NormEps float32
|
||||
}
|
||||
|
||||
// newTransformerBlock creates a transformer block
|
||||
func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
|
||||
attn, err := newJointAttention(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imgMLP, _ := newMLP(weights, prefix+".img_mlp")
|
||||
txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
|
||||
|
||||
imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
|
||||
imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
|
||||
txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
|
||||
txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
|
||||
|
||||
return &TransformerBlock{
|
||||
Attention: attn,
|
||||
ImgMLP: imgMLP,
|
||||
TxtMLP: txtMLP,
|
||||
ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
|
||||
ImgModBias: imgModBias,
|
||||
TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
|
||||
TxtModBias: txtModBias,
|
||||
HiddenDim: cfg.HiddenDim,
|
||||
NormEps: cfg.NormEps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
// Compute modulation: silu(temb) -> linear -> [B, 6*D]
|
||||
siluT := mlx.SiLU(temb)
|
||||
imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
|
||||
txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
|
||||
|
||||
// Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
|
||||
imgModParts := splitMod6(imgMod, tb.HiddenDim)
|
||||
txtModParts := splitMod6(txtMod, tb.HiddenDim)
|
||||
|
||||
// Pre-attention: norm + modulate
|
||||
imgNorm := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
|
||||
|
||||
txtNorm := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
|
||||
|
||||
// Joint attention
|
||||
attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
|
||||
|
||||
// Pre-MLP: norm + modulate
|
||||
imgNorm2 := layerNormNoAffine(img, tb.NormEps)
|
||||
imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
|
||||
|
||||
txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
|
||||
txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
|
||||
|
||||
// MLP
|
||||
mlpImg := tb.ImgMLP.Forward(imgNorm2)
|
||||
mlpTxt := tb.TxtMLP.Forward(txtNorm2)
|
||||
|
||||
// Residual with gate
|
||||
img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
|
||||
txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
|
||||
|
||||
return img, txt
|
||||
}
|
||||
|
||||
// splitMod6 splits modulation into 6 parts each [B, 1, D]
|
||||
func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
|
||||
shape := mod.Shape()
|
||||
B := shape[0]
|
||||
parts := make([]*mlx.Array, 6)
|
||||
for i := int32(0); i < 6; i++ {
|
||||
part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
|
||||
parts[i] = mlx.ExpandDims(part, 1)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Qwen-Image transformer model
|
||||
type Transformer struct {
|
||||
Config *TransformerConfig
|
||||
|
||||
ImgIn *mlx.Array
|
||||
ImgInBias *mlx.Array
|
||||
TxtIn *mlx.Array
|
||||
TxtInBias *mlx.Array
|
||||
TxtNorm *mlx.Array
|
||||
|
||||
TEmbed *TimestepEmbedder
|
||||
Layers []*TransformerBlock
|
||||
|
||||
NormOutWeight *mlx.Array
|
||||
NormOutBias *mlx.Array
|
||||
ProjOut *mlx.Array
|
||||
ProjOutBias *mlx.Array
|
||||
}
|
||||
|
||||
// Load loads the transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image transformer...")
|
||||
|
||||
cfg := defaultTransformerConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading input projections... ")
|
||||
imgIn, _ := weights.Get("img_in.weight")
|
||||
imgInBias, _ := weights.Get("img_in.bias")
|
||||
txtIn, _ := weights.Get("txt_in.weight")
|
||||
txtInBias, _ := weights.Get("txt_in.bias")
|
||||
txtNorm, _ := weights.Get("txt_norm.weight")
|
||||
m.ImgIn = mlx.Transpose(imgIn, 1, 0)
|
||||
m.ImgInBias = imgInBias
|
||||
m.TxtIn = mlx.Transpose(txtIn, 1, 0)
|
||||
m.TxtInBias = txtInBias
|
||||
m.TxtNorm = txtNorm
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading timestep embedder... ")
|
||||
m.TEmbed, err = newTimestepEmbedder(weights)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timestep embedder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
for i := int32(0); i < cfg.NLayers; i++ {
|
||||
fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
|
||||
prefix := fmt.Sprintf("transformer_blocks.%d", i)
|
||||
m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOutWeight, _ := weights.Get("norm_out.linear.weight")
|
||||
normOutBias, _ := weights.Get("norm_out.linear.bias")
|
||||
projOut, _ := weights.Get("proj_out.weight")
|
||||
projOutBias, _ := weights.Get("proj_out.bias")
|
||||
m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
|
||||
m.NormOutBias = normOutBias
|
||||
m.ProjOut = mlx.Transpose(projOut, 1, 0)
|
||||
m.ProjOutBias = projOutBias
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath is a convenience function to load transformer from path
|
||||
func LoadTransformerFromPath(path string) (*Transformer, error) {
|
||||
m := &Transformer{}
|
||||
if err := m.Load(filepath.Join(path, "transformer")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward runs the transformer
|
||||
// img: [B, L_img, in_channels] patchified latents
|
||||
// txt: [B, L_txt, joint_attention_dim] text embeddings
|
||||
// t: [B] timesteps (0-1)
|
||||
// imgFreqs, txtFreqs: RoPE frequencies
|
||||
func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
for _, layer := range tr.Layers {
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
// Python: scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for speedup.
|
||||
// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
|
||||
// shallow layers change little between denoising steps, so we cache their
|
||||
// outputs and reuse them on non-refresh steps.
|
||||
//
|
||||
// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
|
||||
// step: current denoising step (0-indexed)
|
||||
// cacheInterval: refresh cache every N steps (e.g., 3)
|
||||
// cacheLayers: number of shallow layers to cache (e.g., 15)
|
||||
func (tr *Transformer) ForwardWithCache(
|
||||
img, txt, t *mlx.Array,
|
||||
imgFreqs, txtFreqs *mlx.Array,
|
||||
stepCache *cache.StepCache,
|
||||
step, cacheInterval, cacheLayers int,
|
||||
) *mlx.Array {
|
||||
imgShape := img.Shape()
|
||||
B := imgShape[0]
|
||||
Limg := imgShape[1]
|
||||
|
||||
txtShape := txt.Shape()
|
||||
Ltxt := txtShape[1]
|
||||
|
||||
// Timestep embedding
|
||||
temb := tr.TEmbed.Forward(t)
|
||||
|
||||
// Project image: [B, L, in_channels] -> [B, L, hidden_dim]
|
||||
imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
|
||||
imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
|
||||
imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
|
||||
|
||||
// Project text: RMSNorm then linear
|
||||
txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
|
||||
txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
|
||||
txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
|
||||
txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
|
||||
|
||||
// Check if we should refresh the cache
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
for i, layer := range tr.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached outputs for shallow layers
|
||||
imgH = stepCache.Get(i)
|
||||
txtH = stepCache.Get2(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
|
||||
// Cache shallow layers on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, imgH)
|
||||
stepCache.Set2(i, txtH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm with modulation (AdaLayerNormContinuous)
|
||||
finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
|
||||
modShape := finalMod.Shape()
|
||||
halfDim := modShape[1] / 2
|
||||
scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
|
||||
shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
|
||||
|
||||
imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
|
||||
imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
|
||||
|
||||
// Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
|
||||
imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
|
||||
out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
|
||||
|
||||
outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
|
||||
return mlx.Reshape(out, B, Limg, outChannels)
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE frequencies
|
||||
type RoPECache struct {
|
||||
ImgFreqs *mlx.Array // [L_img, head_dim]
|
||||
TxtFreqs *mlx.Array // [L_txt, head_dim]
|
||||
}
|
||||
|
||||
// PrepareRoPE computes RoPE for image and text sequences
|
||||
// This matches Python's QwenEmbedRope with scale_rope=True
|
||||
func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
// Image frequencies with scale_rope=True
|
||||
imgLen := imgH * imgW
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
imgFreqsData := make([]float32, imgLen*headDim)
|
||||
|
||||
hHalf := imgH / 2
|
||||
wHalf := imgW / 2
|
||||
|
||||
idx := int32(0)
|
||||
for y := int32(0); y < imgH; y++ {
|
||||
for x := int32(0); x < imgW; x++ {
|
||||
// Frame = 0
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
|
||||
// Height: scale_rope pattern
|
||||
hNegCount := imgH - hHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
|
||||
// Width: scale_rope pattern
|
||||
wNegCount := imgW - wHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies
|
||||
maxVidIdx := max(hHalf, wHalf)
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
|
||||
func ComputeAxisFreqs(dim int32, theta float64) []float64 {
|
||||
halfDim := dim / 2
|
||||
freqs := make([]float64, halfDim)
|
||||
for i := int32(0); i < halfDim; i++ {
|
||||
freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
|
||||
}
|
||||
return freqs
|
||||
}
|
||||
|
||||
// MakeFreqTable builds a table of cos/sin values for RoPE positions.
|
||||
func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
|
||||
table := make([][]float32, maxIdx)
|
||||
for idx := int32(0); idx < maxIdx; idx++ {
|
||||
var pos float64
|
||||
if negative {
|
||||
pos = float64(-maxIdx + int32(idx))
|
||||
} else {
|
||||
pos = float64(idx)
|
||||
}
|
||||
|
||||
row := make([]float32, len(baseFreqs)*2)
|
||||
for i, f := range baseFreqs {
|
||||
angle := pos * f
|
||||
row[i*2] = float32(math.Cos(angle))
|
||||
row[i*2+1] = float32(math.Sin(angle))
|
||||
}
|
||||
table[idx] = row
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func max(a, b int32) int32 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
|
||||
func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// -> [B, pH, pW, C, 2, 2]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// -> [B, pH*pW, C*4]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
|
||||
func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
channels := shape[2] / (patchSize * patchSize)
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [B, L, C*4] -> [B, pH, pW, C, 2, 2]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// -> [B, C, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// -> [B, C, H, W]
|
||||
x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
// Add temporal dimension for VAE: [B, C, 1, H, W]
|
||||
return mlx.ExpandDims(x, 2)
|
||||
}
|
||||
119
x/imagegen/models/qwen_image/transformer_test.go
Normal file
119
x/imagegen/models/qwen_image/transformer_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestTransformerConfig tests configuration invariants.
|
||||
func TestTransformerConfig(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Property: hidden_dim = n_heads * head_dim
|
||||
if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
|
||||
t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
|
||||
cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: axes_dims_rope sums to head_dim
|
||||
var ropeSum int32
|
||||
for _, d := range cfg.AxesDimsRope {
|
||||
ropeSum += d
|
||||
}
|
||||
if ropeSum != cfg.HeadDim {
|
||||
t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
|
||||
}
|
||||
|
||||
// Property: in_channels = out_channels * patch_size^2
|
||||
expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
|
||||
if cfg.InChannels != expectedIn {
|
||||
t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerRoPE tests RoPE frequency computation produces valid values.
|
||||
func TestTransformerRoPE(t *testing.T) {
|
||||
cfg := defaultTransformerConfig()
|
||||
|
||||
// Test with small image dimensions
|
||||
imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
|
||||
txtLen := int32(5)
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Verify shapes: [seq_len, head_dim]
|
||||
imgSeqLen := imgH * imgW
|
||||
if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
|
||||
t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
|
||||
}
|
||||
if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
|
||||
t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
|
||||
}
|
||||
|
||||
if ropeCache.TxtFreqs.Shape()[0] != txtLen {
|
||||
t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
|
||||
}
|
||||
|
||||
// Verify values are finite
|
||||
imgData := ropeCache.ImgFreqs.Data()
|
||||
for i := 0; i < min(100, len(imgData)); i++ {
|
||||
if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
|
||||
t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransformerForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestTransformerForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/transformer"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
transformer := &Transformer{}
|
||||
if err := transformer.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load transformer: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(transformer)...)
|
||||
cfg := transformer.Config
|
||||
|
||||
// Small test inputs
|
||||
batchSize := int32(1)
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
imgSeqLen := imgH * imgW
|
||||
txtSeqLen := int32(5)
|
||||
|
||||
hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
|
||||
encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
|
||||
timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
|
||||
|
||||
ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
|
||||
|
||||
// Forward pass
|
||||
out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [batch, img_seq_len, in_channels]
|
||||
wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
|
||||
gotShape := out.Shape()
|
||||
if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
|
||||
t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
|
||||
}
|
||||
|
||||
// Verify output is finite
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
854
x/imagegen/models/qwen_image/vae.go
Normal file
854
x/imagegen/models/qwen_image/vae.go
Normal file
@@ -0,0 +1,854 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
// Input is channels-last: [B, T, H, W, C]
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
// Works with channels-last [B, T, H, W, C] format
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Use h as working variable, keep x intact for residual (caller will free x)
|
||||
// Conv handles its own pools, so we just need pools for non-conv operations
|
||||
var h *mlx.Array
|
||||
|
||||
// Keep x so it survives Eval() cleanup - needed for residual connection
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2 (handles its own pools)
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection (shortcut handles its own pools if present)
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V: [B*T, H*W, 3*C]
|
||||
// Weight is [outC, inC] or [outC, inC, 1, 1]
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0) // [inC, outC]
|
||||
|
||||
qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
// out: [B*T, 1, H*W, C]
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0) // [inC, outC]
|
||||
out = mlx.Linear(out, p) // [B*T, H*W, C]
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block with staged memory management
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// ResBlocks handle their own pools
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Upsampler handles its own pools
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
// Uses staged pools to reduce peak memory during 2x upsampling
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block of decoder
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Each component handles its own pools; we just free inputs
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image VAE decoder...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Bulk load all weights as bf16
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading post_quant_conv... ")
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.PostQuantConv = postQuantConv
|
||||
fmt.Println("✓")
|
||||
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = convIn
|
||||
fmt.Println("✓")
|
||||
|
||||
// Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
|
||||
fmt.Print(" Loading mid_block... ")
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.MidBlock = midBlock
|
||||
fmt.Println("✓")
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
fmt.Print(" Loading up_blocks... ")
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
m.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.UpBlocks[i] = upBlock
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
|
||||
|
||||
fmt.Print(" Loading output layers... ")
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.NormOut = normOut
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = convOut
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadVAEDecoderFromPath is a convenience function to load VAE from path
|
||||
func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
|
||||
m := &VAEDecoder{}
|
||||
if err := m.Load(filepath.Join(path, "vae")); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] normalized latents
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Stage 1a: Denormalize and transpose
|
||||
{
|
||||
z = vae.Denormalize(z)
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// Stage 1b: PostQuantConv (handles its own pools)
|
||||
x = vae.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// Stage 1c: ConvIn (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 2: Mid block (handles its own pools)
|
||||
x = vae.MidBlock.Forward(x)
|
||||
|
||||
// Stage 3: Up blocks (each handles its own pools)
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// Stage 4a: NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = vae.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 4b: ConvOut (handles its own pools)
|
||||
{
|
||||
prev := x
|
||||
x = vae.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 4c: Post-processing
|
||||
{
|
||||
prev := x
|
||||
// Clamp to [-1, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
// Convert back from channels-last to channels-first
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Denormalize reverses the normalization applied during encoding
|
||||
func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
mean = mlx.ToBFloat16(mean)
|
||||
std = mlx.ToBFloat16(std)
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
// Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
|
||||
}
|
||||
|
||||
func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
x = mlx.Reshape(x, B*H*W, shape[1])
|
||||
|
||||
wShape := weight.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(weight, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = weight
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
out := mlx.Linear(x, w)
|
||||
outC := w.Dim(1)
|
||||
out = mlx.Reshape(out, B, H, W, outC)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
|
||||
x = pad2D(x, 1, 1, 1, 1)
|
||||
return conv2D(x, weight, 1, 1)
|
||||
}
|
||||
|
||||
func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
|
||||
x = mlx.Transpose(x, 0, 2, 3, 1)
|
||||
w = mlx.Transpose(w, 0, 2, 3, 1)
|
||||
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := w.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := extractPatches2D(x, kH, kW, strideH, strideW)
|
||||
wFlat := mlx.Reshape(w, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
out = mlx.Reshape(out, B, outH, outW, Cout)
|
||||
return mlx.Transpose(out, 0, 3, 1, 2)
|
||||
}
|
||||
|
||||
func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H-kH)/strideH + 1
|
||||
outW := (W-kW)/strideW + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 2)
|
||||
x = mlx.Take(x, colIdx, 3)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
// Create repeat indices for rows
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
// Create repeat indices for columns
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
// Take along H (axis 1) then W (axis 2)
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
// weight: [outC, kH, kW, inC] (MLX channels-last format)
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
// Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
// Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
|
||||
// stride=1, padding=0 (we already padded manually)
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
114
x/imagegen/models/qwen_image/vae_test.go
Normal file
114
x/imagegen/models/qwen_image/vae_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestVAEConfig tests configuration invariants.
|
||||
func TestVAEConfig(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Property: latents_mean and latents_std have z_dim elements
|
||||
if int32(len(cfg.LatentsMean)) != cfg.ZDim {
|
||||
t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
|
||||
}
|
||||
if int32(len(cfg.LatentsStd)) != cfg.ZDim {
|
||||
t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
|
||||
}
|
||||
|
||||
// Property: dim_mult defines 4 stages
|
||||
if len(cfg.DimMult) != 4 {
|
||||
t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
|
||||
}
|
||||
|
||||
// Property: temperal_downsample has 3 elements (for 3 transitions)
|
||||
if len(cfg.TemperalDownsample) != 3 {
|
||||
t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAELatentsNormalization tests the latent denormalization values.
|
||||
func TestVAELatentsNormalization(t *testing.T) {
|
||||
cfg := defaultVAEConfig()
|
||||
|
||||
// Verify latents_std values are all positive
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std <= 0 {
|
||||
t.Errorf("latents_std[%d] should be positive: %v", i, std)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are in reasonable range (from actual model)
|
||||
for i, mean := range cfg.LatentsMean {
|
||||
if math.Abs(float64(mean)) > 5 {
|
||||
t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
|
||||
}
|
||||
}
|
||||
for i, std := range cfg.LatentsStd {
|
||||
if std > 10 {
|
||||
t.Errorf("latents_std[%d] seems too large: %v", i, std)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVAEDecoderForward tests full forward pass (integration test).
|
||||
// Skips if model weights are not available.
|
||||
func TestVAEDecoderForward(t *testing.T) {
|
||||
weightsPath := "../../../weights/Qwen-Image-2512/vae"
|
||||
if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
|
||||
t.Skip("Skipping: model weights not found at " + weightsPath)
|
||||
}
|
||||
|
||||
vae := &VAEDecoder{}
|
||||
if err := vae.Load(weightsPath); err != nil {
|
||||
t.Fatalf("Failed to load VAE decoder: %v", err)
|
||||
}
|
||||
mlx.Keep(mlx.Collect(vae)...)
|
||||
|
||||
// Small test input: [B, C, T, H, W]
|
||||
// After 4 upsampling stages (2x each), H/W multiply by 16
|
||||
batchSize := int32(1)
|
||||
channels := int32(16)
|
||||
frames := int32(1)
|
||||
latentH := int32(4)
|
||||
latentW := int32(4)
|
||||
|
||||
latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
|
||||
|
||||
// Decode
|
||||
out := vae.Decode(latents)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Verify output shape: [B, 3, T, H*16, W*16]
|
||||
outShape := out.Shape()
|
||||
if outShape[0] != batchSize {
|
||||
t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
|
||||
}
|
||||
if outShape[1] != 3 {
|
||||
t.Errorf("channels: got %d, want 3", outShape[1])
|
||||
}
|
||||
if outShape[2] != frames {
|
||||
t.Errorf("frames: got %d, want %d", outShape[2], frames)
|
||||
}
|
||||
expectedH := latentH * 16 // 4 stages of 2x upsampling
|
||||
expectedW := latentW * 16
|
||||
if outShape[3] != expectedH || outShape[4] != expectedW {
|
||||
t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
|
||||
outShape[3], outShape[4], expectedH, expectedW)
|
||||
}
|
||||
|
||||
// Verify output is in valid range (should be clamped to [0, 1] by decode)
|
||||
outData := out.Data()
|
||||
for i := 0; i < min(100, len(outData)); i++ {
|
||||
if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
|
||||
t.Errorf("output[%d] not finite: %v", i, outData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
682
x/imagegen/models/qwen_image_edit/layers.go
Normal file
682
x/imagegen/models/qwen_image_edit/layers.go
Normal file
@@ -0,0 +1,682 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CausalConv3d is a causal 3D convolution (for temporal causality)
|
||||
type CausalConv3d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
|
||||
KernelT int32
|
||||
}
|
||||
|
||||
// newCausalConv3d creates a 3D causal conv
|
||||
func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
|
||||
weight, err := weights.Get(prefix + ".weight")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weight not found: %s", prefix)
|
||||
}
|
||||
bias, _ := weights.Get(prefix + ".bias")
|
||||
|
||||
kernelT := weight.Shape()[2]
|
||||
outC := weight.Shape()[0]
|
||||
|
||||
var biasReshaped *mlx.Array
|
||||
if bias != nil {
|
||||
biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
|
||||
}
|
||||
|
||||
return &CausalConv3d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
BiasReshaped: biasReshaped,
|
||||
KernelT: kernelT,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies causal 3D convolution (or 2D if weight is 4D)
|
||||
// x: [B, T, H, W, C] (channels-last, MLX format)
|
||||
func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := c.Weight.Shape()
|
||||
|
||||
// Handle both 5D (3D conv) and 4D (2D conv) weights
|
||||
if len(shape) == 4 {
|
||||
// 2D conv: [O, I, kH, kW] - need to apply per-frame
|
||||
return c.forward2D(x)
|
||||
}
|
||||
|
||||
// 3D conv: [O, I, kT, kH, kW]
|
||||
kernelT := shape[2]
|
||||
kernelH := shape[3]
|
||||
kernelW := shape[4]
|
||||
|
||||
// Causal temporal padding, same spatial padding
|
||||
padT := kernelT - 1
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Stage 1: Pad
|
||||
{
|
||||
x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
var out *mlx.Array
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
|
||||
out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
|
||||
func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := c.Weight.Shape() // [O, I, kH, kW]
|
||||
kernelH := wShape[2]
|
||||
kernelW := wShape[3]
|
||||
outC := wShape[0]
|
||||
|
||||
padH := kernelH / 2
|
||||
padW := kernelW / 2
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad spatially
|
||||
x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
|
||||
|
||||
// Apply 2D conv
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = mlx.Conv2d(x, weight, 1, 0)
|
||||
|
||||
if c.Bias != nil {
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Get output spatial dims
|
||||
outH := H
|
||||
outW := W
|
||||
|
||||
// Reshape back to [B, T, H, W, C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// RMSNorm3D applies RMS normalization over channels
|
||||
type RMSNorm3D struct {
|
||||
Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
|
||||
}
|
||||
|
||||
// newRMSNorm3D creates an RMS norm
|
||||
func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
|
||||
gamma, err := weights.Get(prefix + ".gamma")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
|
||||
return &RMSNorm3D{Gamma: gamma}, nil
|
||||
}
|
||||
|
||||
// Forward applies RMS norm to channels-last input [B, T, H, W, C]
|
||||
func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
|
||||
normalized := mlx.RMSNormNoWeight(x, 1e-6)
|
||||
return mlx.Mul(normalized, n.Gamma)
|
||||
}
|
||||
|
||||
// ResBlock is a residual block with RMS norm and causal convs
|
||||
type ResBlock struct {
|
||||
Norm1 *RMSNorm3D
|
||||
Conv1 *CausalConv3d
|
||||
Norm2 *RMSNorm3D
|
||||
Conv2 *CausalConv3d
|
||||
Shortcut *CausalConv3d
|
||||
}
|
||||
|
||||
// newResBlock creates a residual block
|
||||
func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
|
||||
norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1, err := newCausalConv3d(weights, prefix+".conv1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2, err := newCausalConv3d(weights, prefix+".conv2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shortcut *CausalConv3d
|
||||
if inDim != outDim {
|
||||
shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ResBlock{
|
||||
Norm1: norm1,
|
||||
Conv1: conv1,
|
||||
Norm2: norm2,
|
||||
Conv2: conv2,
|
||||
Shortcut: shortcut,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the residual block
|
||||
func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
mlx.Keep(x)
|
||||
|
||||
// Stage 1: norm1 + silu
|
||||
{
|
||||
h = r.Norm1.Forward(x)
|
||||
h = silu3D(h)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: conv1
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Stage 3: norm2 + silu
|
||||
{
|
||||
prev := h
|
||||
h = r.Norm2.Forward(h)
|
||||
h = silu3D(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: conv2
|
||||
{
|
||||
prev := h
|
||||
h = r.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
if r.Shortcut != nil {
|
||||
shortcut := r.Shortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
mlx.Eval(h)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// AttentionBlock is a 2D attention block
|
||||
type AttentionBlock struct {
|
||||
Norm *RMSNorm3D
|
||||
ToQKV *mlx.Array
|
||||
ToQKVBias *mlx.Array
|
||||
Proj *mlx.Array
|
||||
ProjBias *mlx.Array
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// newAttentionBlock creates an attention block
|
||||
func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
|
||||
norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
|
||||
toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
|
||||
proj, _ := weights.Get(prefix + ".proj.weight")
|
||||
projBias, _ := weights.Get(prefix + ".proj.bias")
|
||||
|
||||
return &AttentionBlock{
|
||||
Norm: norm,
|
||||
ToQKV: toQKV,
|
||||
ToQKVBias: toQKVBias,
|
||||
Proj: proj,
|
||||
ProjBias: projBias,
|
||||
Dim: dim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies 2D attention
|
||||
// Input: [B, T, H, W, C] (channels-last)
|
||||
func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
identity := x
|
||||
|
||||
// Flatten to [B*T, 1, H, W, C] for norm
|
||||
x = mlx.Reshape(x, B*T, 1, H, W, C)
|
||||
x = a.Norm.Forward(x)
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Flatten spatial to [B*T, H*W, C]
|
||||
x = mlx.Reshape(x, B*T, H*W, C)
|
||||
|
||||
// Linear to get Q, K, V
|
||||
wShape := a.ToQKV.Shape()
|
||||
var w *mlx.Array
|
||||
if len(wShape) == 4 {
|
||||
w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
|
||||
} else {
|
||||
w = a.ToQKV
|
||||
}
|
||||
w = mlx.Transpose(w, 1, 0)
|
||||
|
||||
qkv := mlx.Linear(x, w)
|
||||
if a.ToQKVBias != nil {
|
||||
qkv = mlx.Add(qkv, a.ToQKVBias)
|
||||
}
|
||||
qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
|
||||
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
|
||||
out = mlx.Reshape(out, B*T, H*W, C)
|
||||
|
||||
// Project back
|
||||
pShape := a.Proj.Shape()
|
||||
var p *mlx.Array
|
||||
if len(pShape) == 4 {
|
||||
p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
|
||||
} else {
|
||||
p = a.Proj
|
||||
}
|
||||
p = mlx.Transpose(p, 1, 0)
|
||||
out = mlx.Linear(out, p)
|
||||
if a.ProjBias != nil {
|
||||
out = mlx.Add(out, a.ProjBias)
|
||||
}
|
||||
|
||||
out = mlx.Reshape(out, B, T, H, W, C)
|
||||
return mlx.Add(out, identity)
|
||||
}
|
||||
|
||||
// UpBlock handles upsampling in decoder
|
||||
type UpBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Upsampler *Upsample
|
||||
}
|
||||
|
||||
// newUpBlock creates an up block
|
||||
func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var upsampler *Upsample
|
||||
if upsampleMode != "" {
|
||||
upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
|
||||
}
|
||||
|
||||
return &UpBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Upsampler: upsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies up block
|
||||
func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range u.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if u.Upsampler != nil {
|
||||
prev := x
|
||||
x = u.Upsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Upsample handles spatial upsampling
|
||||
type Upsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newUpsample creates an upsampler
|
||||
func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Upsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies upsampling to channels-last input [B, T, H, W, C]
|
||||
func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := u.Conv.Shape()[0]
|
||||
|
||||
// Stage 1: 2x nearest neighbor upsample
|
||||
{
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
x = upsample2xChannelsLast(x)
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Conv + bias
|
||||
{
|
||||
prev := x
|
||||
weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
|
||||
x = conv2D3x3PaddedChannelsLast(x, weight)
|
||||
if u.Bias != nil {
|
||||
bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
x = mlx.Reshape(x, B, T, H*2, W*2, outC)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// MidBlock is the middle block
|
||||
type MidBlock struct {
|
||||
ResBlock1 *ResBlock
|
||||
Attention *AttentionBlock
|
||||
ResBlock2 *ResBlock
|
||||
}
|
||||
|
||||
// newMidBlock creates a mid block
|
||||
func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
|
||||
res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MidBlock{
|
||||
ResBlock1: res1,
|
||||
Attention: attn,
|
||||
ResBlock2: res2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies mid block
|
||||
func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = m.ResBlock1.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = m.ResBlock2.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func silu3D(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Mul(x, mlx.Sigmoid(x))
|
||||
}
|
||||
|
||||
// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
|
||||
func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
|
||||
if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
|
||||
return x
|
||||
}
|
||||
return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
|
||||
}
|
||||
|
||||
// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
|
||||
func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
rowIdxData := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
rowIdxData[i*2] = i
|
||||
rowIdxData[i*2+1] = i
|
||||
}
|
||||
rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
|
||||
|
||||
colIdxData := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
colIdxData[i*2] = i
|
||||
colIdxData[i*2+1] = i
|
||||
}
|
||||
colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
|
||||
|
||||
x = mlx.Take(x, rowIdx, 1)
|
||||
x = mlx.Take(x, colIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
|
||||
func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
return mlx.Conv2d(x, weight, 1, 0)
|
||||
}
|
||||
|
||||
// conv2DStrided applies conv with stride > 1 using manual patch extraction
|
||||
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
|
||||
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
kH := wShape[1]
|
||||
kW := wShape[2]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := extractPatches2DStrided(x, kH, kW, stride)
|
||||
wFlat := mlx.Reshape(weight, Cout, -1)
|
||||
patches = mlx.Reshape(patches, B*outH*outW, -1)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// conv3DStrided applies 3D conv with strides using manual patch extraction
|
||||
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
|
||||
// strideT, strideH, strideW are the strides for each dimension
|
||||
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
|
||||
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
wShape := weight.Shape()
|
||||
Cout := wShape[0]
|
||||
// I := wShape[1]
|
||||
kT := wShape[2]
|
||||
kH := wShape[3]
|
||||
kW := wShape[4]
|
||||
|
||||
// For temporal: if T < kT, we need to repeat frames temporally
|
||||
// For single image with T=1 and kT=2, we duplicate the frame to T=kT
|
||||
// Python Qwen2.5-VL duplicates the frame, not zero-pads
|
||||
if T < kT {
|
||||
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
|
||||
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
|
||||
T = kT
|
||||
}
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
// Extract 3D patches in [C, T, H, W] order to match Python
|
||||
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
|
||||
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
|
||||
|
||||
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
|
||||
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
|
||||
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
|
||||
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
|
||||
return mlx.Reshape(out, B, outT, outH, outW, Cout)
|
||||
}
|
||||
|
||||
// extractPatches3DStrided extracts 3D patches with given strides
|
||||
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
|
||||
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
|
||||
outT := (T - kT) / strideT + 1
|
||||
outH := (H - kH) / strideH + 1
|
||||
outW := (W - kW) / strideW + 1
|
||||
|
||||
numPatches := outT * outH * outW
|
||||
patches := make([]*mlx.Array, numPatches)
|
||||
idx := 0
|
||||
for t := int32(0); t < outT; t++ {
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startT := t * strideT
|
||||
startH := i * strideH
|
||||
startW := j * strideW
|
||||
// Extract patch: [B, kT, kH, kW, C]
|
||||
patch := mlx.Slice(x,
|
||||
[]int32{0, startT, startH, startW, 0},
|
||||
[]int32{B, startT + kT, startH + kH, startW + kW, C})
|
||||
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
|
||||
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
|
||||
// Flatten to [B, C*T*H*W]
|
||||
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
|
||||
}
|
||||
|
||||
// extractPatches2DStrided extracts patches with given stride
|
||||
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
outH := (H - kH) / stride + 1
|
||||
outW := (W - kW) / stride + 1
|
||||
|
||||
patches := make([]*mlx.Array, outH*outW)
|
||||
idx := 0
|
||||
for i := int32(0); i < outH; i++ {
|
||||
for j := int32(0); j < outW; j++ {
|
||||
startH := i * stride
|
||||
startW := j * stride
|
||||
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
|
||||
patch = mlx.Reshape(patch, B, kH*kW*C)
|
||||
patches[idx] = patch
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range patches {
|
||||
patches[i] = mlx.ExpandDims(patches[i], 1)
|
||||
}
|
||||
stacked := mlx.Concatenate(patches, 1)
|
||||
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
475
x/imagegen/models/qwen_image_edit/processor.go
Normal file
475
x/imagegen/models/qwen_image_edit/processor.go
Normal file
@@ -0,0 +1,475 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"golang.org/x/image/draw"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// loadImageFile loads an image from disk
|
||||
func loadImageFile(path string) (image.Image, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open image: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
img, _, err := image.Decode(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
|
||||
func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
|
||||
pixels := make([]float32, width*height*3)
|
||||
idx := 0
|
||||
for y := 0; y < height; y++ {
|
||||
for x := 0; x < width; x++ {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
pixels[idx] = float32(r) / 65535.0
|
||||
pixels[idx+1] = float32(g) / 65535.0
|
||||
pixels[idx+2] = float32(b) / 65535.0
|
||||
idx += 3
|
||||
}
|
||||
}
|
||||
return pixels
|
||||
}
|
||||
|
||||
// normalizeImageNet applies ImageNet normalization to an image tensor
|
||||
func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
|
||||
mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
|
||||
std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
|
||||
return mlx.Div(mlx.Sub(arr, mean), std)
|
||||
}
|
||||
|
||||
// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
|
||||
func prepareImageTensor(arr *mlx.Array) *mlx.Array {
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
// Add batch dimension [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0)
|
||||
// Convert to bf16
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// clampFloat clamps a value to [0, 255] and returns uint8
|
||||
func clampFloat(v, weightSum float64) uint8 {
|
||||
v /= weightSum
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
if v > 255 {
|
||||
v = 255
|
||||
}
|
||||
return uint8(math.Round(v))
|
||||
}
|
||||
|
||||
// ImageDims holds dimensions for a preprocessed image
|
||||
type ImageDims struct {
|
||||
// Original image dimensions
|
||||
OrigW, OrigH int32
|
||||
// Condition image dimensions (for vision encoder)
|
||||
CondW, CondH int32
|
||||
// VAE image dimensions
|
||||
VaeW, VaeH int32
|
||||
// Latent dimensions (VAE dims / vae_scale_factor)
|
||||
LatentW, LatentH int32
|
||||
// Patch dimensions (latent dims / patch_size)
|
||||
PatchW, PatchH int32
|
||||
}
|
||||
|
||||
// ProcessorConfig holds image processor configuration
|
||||
type ProcessorConfig struct {
|
||||
// Condition image size (target pixel area for vision encoder input)
|
||||
// Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
|
||||
// Pipeline resizes image to this area before passing to encode_prompt
|
||||
ConditionImageSize int32
|
||||
|
||||
// VAE image size (target pixel area)
|
||||
// Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
|
||||
VAEImageSize int32
|
||||
|
||||
// Image normalization (ImageNet stats for vision encoder)
|
||||
ImageMean []float32
|
||||
ImageStd []float32
|
||||
}
|
||||
|
||||
// defaultProcessorConfig returns default processor config
|
||||
func defaultProcessorConfig() *ProcessorConfig {
|
||||
return &ProcessorConfig{
|
||||
ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
|
||||
VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
|
||||
ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// Processor handles image preprocessing for Qwen-Image-Edit
|
||||
type Processor struct {
|
||||
Config *ProcessorConfig
|
||||
}
|
||||
|
||||
// Load loads the processor config
|
||||
func (p *Processor) Load(path string) error {
|
||||
p.Config = defaultProcessorConfig()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndPreprocess loads an image and preprocesses it for both paths
|
||||
// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
|
||||
func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := bounds.Dx()
|
||||
origH := bounds.Dy()
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
// Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize
|
||||
condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
|
||||
return condImage, vaeImage, nil
|
||||
}
|
||||
|
||||
// preprocessImageLanczos does single-step Lanczos resize for vision encoder
|
||||
// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
|
||||
// Used by edit_plus pipeline for multi-image input
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
|
||||
// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
|
||||
// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
|
||||
intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
|
||||
resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
|
||||
pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
|
||||
arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
|
||||
arr = p.normalizeImageNet(arr)
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImage converts image to tensor for vision encoder
|
||||
// Returns: [B, C, H, W] normalized tensor
|
||||
func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
|
||||
resized := resizeImageBicubic(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
if normalize {
|
||||
arr = p.normalizeImageNet(arr)
|
||||
}
|
||||
return prepareImageTensor(arr)
|
||||
}
|
||||
|
||||
// preprocessImageForVAE converts image to tensor for VAE encoding
|
||||
// Returns: [B, C, T, H, W] tensor in [-1, 1] range
|
||||
func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
|
||||
resized := resizeImageLanczos(img, int(width), int(height))
|
||||
pixels := imageToFloat32Pixels(resized, int(width), int(height))
|
||||
arr := mlx.NewArray(pixels, []int32{height, width, 3})
|
||||
|
||||
// Scale to [-1, 1]: arr * 2 - 1
|
||||
arr = mlx.MulScalar(arr, 2.0)
|
||||
arr = mlx.AddScalar(arr, -1.0)
|
||||
|
||||
// Transpose to [C, H, W] and make contiguous
|
||||
arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
|
||||
|
||||
// Add batch and temporal dimensions [1, C, 1, H, W]
|
||||
arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
|
||||
arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
|
||||
|
||||
arr = mlx.ToBFloat16(arr)
|
||||
mlx.Eval(arr)
|
||||
return arr
|
||||
}
|
||||
|
||||
// smartResize implements Python Qwen2VL processor's smart_resize logic
|
||||
// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
|
||||
func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
|
||||
// Round to factor
|
||||
hBar := int32(math.Round(float64(height)/float64(factor))) * factor
|
||||
wBar := int32(math.Round(float64(width)/float64(factor))) * factor
|
||||
|
||||
// Ensure minimum factor size
|
||||
if hBar < factor {
|
||||
hBar = factor
|
||||
}
|
||||
if wBar < factor {
|
||||
wBar = factor
|
||||
}
|
||||
|
||||
// Check pixel constraints
|
||||
total := hBar * wBar
|
||||
if total > maxPixels {
|
||||
// Scale down
|
||||
beta := math.Sqrt(float64(maxPixels) / float64(total))
|
||||
hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
|
||||
} else if total < minPixels {
|
||||
// Scale up
|
||||
beta := math.Sqrt(float64(minPixels) / float64(total))
|
||||
hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
// calculateDimensions calculates width and height for a target area while maintaining ratio
|
||||
// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
|
||||
func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
|
||||
width := math.Sqrt(float64(targetArea) * ratio)
|
||||
height := width / ratio
|
||||
|
||||
m := float64(multiple)
|
||||
width = math.Round(width/m) * m
|
||||
height = math.Round(height/m) * m
|
||||
|
||||
// Ensure minimum dimensions
|
||||
if width < m {
|
||||
width = m
|
||||
}
|
||||
if height < m {
|
||||
height = m
|
||||
}
|
||||
|
||||
return int32(width), int32(height)
|
||||
}
|
||||
|
||||
// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
|
||||
func resizeImageLanczos(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
// Lanczos3 kernel (a=3) to match PIL.LANCZOS
|
||||
lanczos3 := &draw.Kernel{
|
||||
Support: 3.0,
|
||||
At: func(t float64) float64 {
|
||||
if t == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if t < 0 {
|
||||
t = -t
|
||||
}
|
||||
if t >= 3.0 {
|
||||
return 0.0
|
||||
}
|
||||
// sinc(t) * sinc(t/3)
|
||||
piT := math.Pi * t
|
||||
return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
|
||||
},
|
||||
}
|
||||
lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
|
||||
// Uses separable interpolation with PIL's coordinate mapping for exact match
|
||||
func resizeImageBicubic(img image.Image, width, height int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
// Convert to RGBA if needed
|
||||
var src *image.RGBA
|
||||
if rgba, ok := img.(*image.RGBA); ok {
|
||||
src = rgba
|
||||
} else {
|
||||
src = image.NewRGBA(bounds)
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
src.Set(x, y, img.At(x, y))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Keys cubic with a=-0.5 (PIL BICUBIC)
|
||||
cubic := func(x float64) float64 {
|
||||
if x < 0 {
|
||||
x = -x
|
||||
}
|
||||
if x < 1 {
|
||||
return 1.5*x*x*x - 2.5*x*x + 1
|
||||
}
|
||||
if x < 2 {
|
||||
return -0.5*x*x*x + 2.5*x*x - 4*x + 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Horizontal pass: srcW -> width, keep srcH rows
|
||||
temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
|
||||
for y := 0; y < srcH; y++ {
|
||||
for dstX := 0; dstX < width; dstX++ {
|
||||
// PIL coordinate mapping: center-to-center
|
||||
srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
|
||||
baseX := int(math.Floor(srcXf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for i := -1; i <= 2; i++ {
|
||||
sx := baseX + i
|
||||
if sx < 0 {
|
||||
sx = 0
|
||||
}
|
||||
if sx >= srcW {
|
||||
sx = srcW - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcXf - float64(baseX+i)))
|
||||
c := src.RGBAAt(sx, y)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
temp.SetRGBA(dstX, y, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Vertical pass: srcH -> height
|
||||
dst := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for x := 0; x < width; x++ {
|
||||
for dstY := 0; dstY < height; dstY++ {
|
||||
srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
|
||||
baseY := int(math.Floor(srcYf))
|
||||
|
||||
var sumR, sumG, sumB, sumA, weightSum float64
|
||||
for j := -1; j <= 2; j++ {
|
||||
sy := baseY + j
|
||||
if sy < 0 {
|
||||
sy = 0
|
||||
}
|
||||
if sy >= srcH {
|
||||
sy = srcH - 1
|
||||
}
|
||||
|
||||
w := cubic(math.Abs(srcYf - float64(baseY+j)))
|
||||
c := temp.RGBAAt(x, sy)
|
||||
sumR += float64(c.R) * w
|
||||
sumG += float64(c.G) * w
|
||||
sumB += float64(c.B) * w
|
||||
sumA += float64(c.A) * w
|
||||
weightSum += w
|
||||
}
|
||||
|
||||
dst.SetRGBA(x, dstY, color.RGBA{
|
||||
clampFloat(sumR, weightSum),
|
||||
clampFloat(sumG, weightSum),
|
||||
clampFloat(sumB, weightSum),
|
||||
clampFloat(sumA, weightSum),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// LoadAndPreprocessMultiple loads multiple images and preprocesses them
|
||||
// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
|
||||
func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
|
||||
const vaeScaleFactor int32 = 8
|
||||
const patchSize int32 = 2
|
||||
|
||||
condImages := make([]*mlx.Array, len(imagePaths))
|
||||
vaeImages := make([]*mlx.Array, len(imagePaths))
|
||||
dims := make([]ImageDims, len(imagePaths))
|
||||
|
||||
for i, imagePath := range imagePaths {
|
||||
img, err := loadImageFile(imagePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
origW := int32(bounds.Dx())
|
||||
origH := int32(bounds.Dy())
|
||||
ratio := float64(origW) / float64(origH)
|
||||
|
||||
// Calculate dimensions for condition image (vision encoder)
|
||||
// Python pipeline does TWO resizes:
|
||||
// 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
|
||||
// 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
|
||||
intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
|
||||
condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
|
||||
|
||||
// Calculate dimensions for VAE image (1024x1024 area)
|
||||
vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
|
||||
|
||||
// Calculate derived dimensions
|
||||
latentW := vaeW / vaeScaleFactor
|
||||
latentH := vaeH / vaeScaleFactor
|
||||
patchW := latentW / patchSize
|
||||
patchH := latentH / patchSize
|
||||
|
||||
dims[i] = ImageDims{
|
||||
OrigW: origW,
|
||||
OrigH: origH,
|
||||
CondW: condW,
|
||||
CondH: condH,
|
||||
VaeW: vaeW,
|
||||
VaeH: vaeH,
|
||||
LatentW: latentW,
|
||||
LatentH: latentH,
|
||||
PatchW: patchW,
|
||||
PatchH: patchH,
|
||||
}
|
||||
|
||||
fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
|
||||
i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
|
||||
|
||||
// Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
|
||||
condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
|
||||
|
||||
// Preprocess for VAE ([-1, 1] range, 5D tensor)
|
||||
vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
|
||||
}
|
||||
|
||||
return condImages, vaeImages, dims, nil
|
||||
}
|
||||
610
x/imagegen/models/qwen_image_edit/qwen_image_edit.go
Normal file
610
x/imagegen/models/qwen_image_edit/qwen_image_edit.go
Normal file
@@ -0,0 +1,610 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
|
||||
// It reuses components from qwen_image where possible.
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Processor *Processor // Image processor for vision encoder
|
||||
TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
|
||||
Transformer *qwen_image.Transformer // Reuse qwen_image transformer
|
||||
VAE *VAE // Combined encoder + decoder
|
||||
}
|
||||
|
||||
// Load loads the Qwen-Image-Edit model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer from processor directory
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
processorPath := filepath.Join(modelPath, "processor")
|
||||
tok, err := tokenizer.Load(processorPath)
|
||||
if err != nil {
|
||||
// Fallback to tokenizer directory
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer")
|
||||
tok, err = tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load processor (image preprocessing config)
|
||||
fmt.Print(" Loading processor... ")
|
||||
m.Processor = &Processor{}
|
||||
if err := m.Processor.Load(processorPath); err != nil {
|
||||
return fmt.Errorf("processor: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
|
||||
m.TextEncoder = &qwen_image.Qwen25VL{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer (reuse qwen_image)
|
||||
m.Transformer = &qwen_image.Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE (encoder + decoder)
|
||||
m.VAE = &VAE{}
|
||||
if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAE)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
peak := mlx.MetalGetPeakMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
|
||||
time.Since(start).Seconds(),
|
||||
float64(mem)/(1024*1024*1024),
|
||||
float64(peak)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Edit edits an image based on a text prompt.
|
||||
// inputImagePath: path to input image
|
||||
// prompt: text description of desired edit
|
||||
func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// EditFromConfig edits images using the unified config struct.
|
||||
// Accepts one or more input images.
|
||||
func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
if len(inputImagePaths) == 0 {
|
||||
return nil, fmt.Errorf("no input images provided")
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := m.edit(inputImagePaths, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
|
||||
len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EditImage implements model.ImageEditModel interface.
|
||||
func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Edit(inputImagePath, prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// EditMultiImage edits using multiple source images.
|
||||
// This matches diffusers' QwenImageEditPlusPipeline behavior.
|
||||
func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
return m.EditFromConfig(inputImagePaths, cfg)
|
||||
}
|
||||
|
||||
// edit is the internal editing pipeline that handles one or more images.
|
||||
func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
|
||||
// Load and preprocess all input images
|
||||
fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
|
||||
condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preprocess images: %w", err)
|
||||
}
|
||||
for _, img := range condImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
for _, img := range vaeImages {
|
||||
mlx.Keep(img)
|
||||
}
|
||||
mlx.Eval(append(condImages, vaeImages...)...)
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.Config
|
||||
vaeScaleFactor := int32(8)
|
||||
|
||||
// Output dimensions - if not specified, use first input image dimensions
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = inputDims[0].VaeW
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = inputDims[0].VaeH
|
||||
}
|
||||
|
||||
// Output (noise) latent dimensions
|
||||
outLatentH := cfg.Height / vaeScaleFactor
|
||||
outLatentW := cfg.Width / vaeScaleFactor
|
||||
outPH := outLatentH / tcfg.PatchSize
|
||||
outPW := outLatentW / tcfg.PatchSize
|
||||
noiseSeqLen := outPH * outPW
|
||||
imgSeqLen := noiseSeqLen
|
||||
|
||||
// Encode prompt with all images for conditioning
|
||||
posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
|
||||
var negEmb *mlx.Array
|
||||
if useCFG {
|
||||
negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding negative prompt: %w", err)
|
||||
}
|
||||
mlx.Keep(negEmb)
|
||||
mlx.Eval(negEmb)
|
||||
}
|
||||
|
||||
// Pad sequences to same length for CFG
|
||||
txtLen := posEmb.Shape()[1]
|
||||
if useCFG {
|
||||
negLen := negEmb.Shape()[1]
|
||||
if negLen > txtLen {
|
||||
txtLen = negLen
|
||||
}
|
||||
if posEmb.Shape()[1] < txtLen {
|
||||
posEmb = padSequence(posEmb, txtLen)
|
||||
}
|
||||
if negEmb.Shape()[1] < txtLen {
|
||||
negEmb = padSequence(negEmb, txtLen)
|
||||
}
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
}
|
||||
|
||||
// Encode all input images to latents and concatenate
|
||||
fmt.Println("Encoding images to latents...")
|
||||
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
|
||||
for i, vaeImage := range vaeImages {
|
||||
imageLatents := m.VAE.Encode(vaeImage)
|
||||
imageLatents = m.VAE.Normalize(imageLatents)
|
||||
imageLatents2D := mlx.Squeeze(imageLatents, 2)
|
||||
packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
|
||||
mlx.Keep(packed)
|
||||
mlx.Eval(packed)
|
||||
allImageLatentsPacked[i] = packed
|
||||
}
|
||||
|
||||
imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
|
||||
mlx.Keep(imageLatentsPacked)
|
||||
mlx.Eval(imageLatentsPacked)
|
||||
|
||||
// Scheduler
|
||||
scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
|
||||
scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
|
||||
|
||||
// Init noise latents in packed format
|
||||
packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
|
||||
packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
|
||||
latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// RoPE cache
|
||||
ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
|
||||
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
t := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
|
||||
mlx.Eval(timestep)
|
||||
|
||||
latents2D := mlx.Squeeze(latents, 2)
|
||||
patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
|
||||
latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
|
||||
|
||||
var output *mlx.Array
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
|
||||
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
|
||||
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
|
||||
|
||||
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
|
||||
} else {
|
||||
output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
|
||||
}
|
||||
|
||||
// Free denoising temporaries
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgFreqs.Free()
|
||||
ropeCache.TxtFreqs.Free()
|
||||
imageLatentsPacked.Free()
|
||||
|
||||
// Decode latents
|
||||
decoded := m.decodeAndPostprocess(latents)
|
||||
latents.Free()
|
||||
|
||||
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
|
||||
// This prevents CFG from inflating magnitude too much.
|
||||
func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
|
||||
// Upcast to float32 for precision
|
||||
posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
|
||||
negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
|
||||
|
||||
// CFG: pred = neg + scale * (pos - neg)
|
||||
diff := mlx.Sub(posF32, negF32)
|
||||
scaledDiff := mlx.MulScalar(diff, scale)
|
||||
combPred := mlx.Add(negF32, scaledDiff)
|
||||
|
||||
// Norm rescaling: rescale combined prediction to match conditional norm
|
||||
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
|
||||
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
|
||||
output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
|
||||
|
||||
mlx.Eval(output)
|
||||
return mlx.ToBFloat16(output)
|
||||
}
|
||||
|
||||
// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
|
||||
func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
|
||||
latents = m.VAE.Denormalize(latents)
|
||||
decoded := m.VAE.Decode(latents)
|
||||
|
||||
// Post-process: squeeze temporal dim and rescale to [0, 1]
|
||||
decoded = mlx.Squeeze(decoded, 2)
|
||||
decoded = mlx.AddScalar(decoded, 1.0)
|
||||
decoded = mlx.DivScalar(decoded, 2.0)
|
||||
decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
|
||||
mlx.Eval(decoded)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// padSequence pads a sequence tensor to the target length with zeros
|
||||
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
// Pad on sequence dimension (axis 1)
|
||||
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
|
||||
}
|
||||
|
||||
// LoadPersistent is an alias for backward compatibility.
|
||||
func LoadPersistent(modelPath string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
|
||||
// Handles single or multiple input images with different resolutions.
|
||||
//
|
||||
// Parameters:
|
||||
// - outPH, outPW: output patch dimensions (noise latent resolution)
|
||||
// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
|
||||
// - txtLen: text sequence length
|
||||
// - axesDims: RoPE axis dimensions [16, 56, 56]
|
||||
//
|
||||
// Returns RoPE cache where:
|
||||
// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
|
||||
// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
|
||||
// - Following positions are for each input image (interpolated from output res)
|
||||
func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
|
||||
theta := float64(10000)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Compute base frequencies for each axis dimension
|
||||
freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
|
||||
freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
|
||||
|
||||
// Build frequency lookup tables
|
||||
posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
|
||||
posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
|
||||
negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
|
||||
negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
|
||||
negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
|
||||
|
||||
headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
|
||||
|
||||
// Helper to compute RoPE for a single position at output resolution with scale_rope
|
||||
computePosFreqs := func(framePos, y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame position
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = posFreqsT[framePos][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Helper to compute RoPE for frame -1 (used for last condition image)
|
||||
// This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
|
||||
computeNegFrameFreqs := func(y, x int32) []float32 {
|
||||
row := make([]float32, headDim)
|
||||
idx := 0
|
||||
|
||||
// Frame -1: use last row of negative frame frequencies
|
||||
negFrameIdx := maxIdx - 1
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
row[idx+i] = negFreqsT[negFrameIdx][i]
|
||||
}
|
||||
idx += len(freqsT) * 2
|
||||
|
||||
// Height with scale_rope centering (using OUTPUT dimensions)
|
||||
outHHalf := outPH / 2
|
||||
hNegCount := outPH - outHHalf
|
||||
if y < hNegCount {
|
||||
negTableIdx := maxIdx - hNegCount + y
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = negFreqsH[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := y - hNegCount
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
row[idx+i] = posFreqsH[posIdx][i]
|
||||
}
|
||||
}
|
||||
idx += len(freqsH) * 2
|
||||
|
||||
// Width with scale_rope centering (using OUTPUT dimensions)
|
||||
outWHalf := outPW / 2
|
||||
wNegCount := outPW - outWHalf
|
||||
if x < wNegCount {
|
||||
negTableIdx := maxIdx - wNegCount + x
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = negFreqsW[negTableIdx][i]
|
||||
}
|
||||
} else {
|
||||
posIdx := x - wNegCount
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
row[idx+i] = posFreqsW[posIdx][i]
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// Total image sequence length: noise + all input images
|
||||
noiseSeqLen := outPH * outPW
|
||||
totalImgLen := noiseSeqLen
|
||||
for _, dims := range inputDims {
|
||||
totalImgLen += dims.PatchH * dims.PatchW
|
||||
}
|
||||
|
||||
imgFreqsData := make([]float32, totalImgLen*headDim)
|
||||
idx := int32(0)
|
||||
|
||||
// Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
|
||||
for y := int32(0); y < outPH; y++ {
|
||||
for x := int32(0); x < outPW; x++ {
|
||||
row := computePosFreqs(0, y, x)
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
|
||||
// Segments 1..N: Edit image latents - INTERPOLATED RoPE
|
||||
// For single image: use frame 1 (matches original PrepareRoPEInterpolated)
|
||||
// For multiple images: Python uses frame -1 for the LAST condition image
|
||||
// (_compute_condition_freqs), positive indices for others.
|
||||
numImages := len(inputDims)
|
||||
lastImgIdx := numImages - 1
|
||||
for imgIdx, dims := range inputDims {
|
||||
inPH := dims.PatchH
|
||||
inPW := dims.PatchW
|
||||
|
||||
// Determine frame index for this image
|
||||
// Single image case: use frame 1 (like original PrepareRoPEInterpolated)
|
||||
// Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
|
||||
useNegFrame := numImages > 1 && imgIdx == lastImgIdx
|
||||
|
||||
// Map each input position to an output position using linear interpolation
|
||||
for y := int32(0); y < inPH; y++ {
|
||||
for x := int32(0); x < inPW; x++ {
|
||||
// Interpolate: map input (y, x) to output grid position
|
||||
// This is the key fix from DiffSynth's forward_sampling
|
||||
var yOut, xOut int32
|
||||
if inPH == 1 {
|
||||
yOut = 0
|
||||
} else {
|
||||
// Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
|
||||
yOut = y * (outPH - 1) / (inPH - 1)
|
||||
}
|
||||
if inPW == 1 {
|
||||
xOut = 0
|
||||
} else {
|
||||
xOut = x * (outPW - 1) / (inPW - 1)
|
||||
}
|
||||
|
||||
var row []float32
|
||||
if useNegFrame {
|
||||
// Last image in multi-image uses frame -1
|
||||
row = computeNegFrameFreqs(yOut, xOut)
|
||||
} else {
|
||||
// Single image uses frame 1, multi-image uses frame 1, 2, etc.
|
||||
frameIdx := int32(imgIdx + 1)
|
||||
row = computePosFreqs(frameIdx, yOut, xOut)
|
||||
}
|
||||
copy(imgFreqsData[idx:], row)
|
||||
idx += headDim
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
|
||||
imgFreqs = mlx.ToBFloat16(imgFreqs)
|
||||
|
||||
// Text frequencies - start after max video index
|
||||
maxVidIdx := max(outPH/2, outPW/2)
|
||||
|
||||
txtFreqsData := make([]float32, txtLen*headDim)
|
||||
idx = 0
|
||||
for t := int32(0); t < txtLen; t++ {
|
||||
pos := maxVidIdx + t
|
||||
for i := 0; i < len(freqsT)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsT) * 2)
|
||||
for i := 0; i < len(freqsH)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsH) * 2)
|
||||
for i := 0; i < len(freqsW)*2; i++ {
|
||||
txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
|
||||
}
|
||||
idx += int32(len(freqsW) * 2)
|
||||
}
|
||||
|
||||
txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
|
||||
txtFreqs = mlx.ToBFloat16(txtFreqs)
|
||||
|
||||
return &qwen_image.RoPECache{
|
||||
ImgFreqs: imgFreqs,
|
||||
TxtFreqs: txtFreqs,
|
||||
}
|
||||
}
|
||||
227
x/imagegen/models/qwen_image_edit/rope_test.go
Normal file
227
x/imagegen/models/qwen_image_edit/rope_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
// Expected values from Python:
|
||||
// freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
|
||||
expectedFreqsT := []float64{
|
||||
1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
|
||||
0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
|
||||
}
|
||||
|
||||
expectedFreqsH_first4 := []float64{
|
||||
1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
|
||||
}
|
||||
|
||||
expectedFreqsH_last4 := []float64{
|
||||
0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
|
||||
}
|
||||
|
||||
// Test temporal frequencies (dim=16)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
if len(freqsT) != 8 {
|
||||
t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
|
||||
}
|
||||
for i, expected := range expectedFreqsT {
|
||||
if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test height/width frequencies (dim=56)
|
||||
freqsH := qwen_image.ComputeAxisFreqs(56, theta)
|
||||
if len(freqsH) != 28 {
|
||||
t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
|
||||
}
|
||||
for i, expected := range expectedFreqsH_first4 {
|
||||
if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
|
||||
}
|
||||
}
|
||||
for i, expected := range expectedFreqsH_last4 {
|
||||
idx := 24 + i // last 4 of 28
|
||||
if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
|
||||
t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
|
||||
func TestMakeFreqTable(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
freqsT := qwen_image.ComputeAxisFreqs(16, theta)
|
||||
maxIdx := int32(4096)
|
||||
|
||||
// Test positive table
|
||||
posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
|
||||
|
||||
// Position 0 should give cos=1, sin=0 for all frequencies
|
||||
for i := 0; i < len(freqsT)*2; i += 2 {
|
||||
if posTable[0][i] != 1.0 {
|
||||
t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
|
||||
}
|
||||
if posTable[0][i+1] != 0.0 {
|
||||
t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
|
||||
}
|
||||
}
|
||||
|
||||
// Position 1, first frequency (1.0): angle = 1*1 = 1
|
||||
// cos(1) = 0.5403, sin(1) = 0.8415
|
||||
if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
|
||||
}
|
||||
if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
|
||||
t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
|
||||
}
|
||||
|
||||
// Test negative table
|
||||
negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
|
||||
|
||||
// negTable[4095] corresponds to position -1
|
||||
// cos(-1) = cos(1), sin(-1) = -sin(1)
|
||||
if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
|
||||
}
|
||||
|
||||
// negTable[4094] corresponds to position -2
|
||||
// cos(-2) = cos(2), sin(-2) = -sin(2)
|
||||
cos2 := math.Cos(2.0)
|
||||
sin2 := math.Sin(2.0)
|
||||
if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
|
||||
}
|
||||
if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
|
||||
t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
|
||||
func TestPrepareRoPE_QwenImage(t *testing.T) {
|
||||
if !mlx.GPUIsAvailable() {
|
||||
t.Skip("GPU not available")
|
||||
}
|
||||
|
||||
mlx.SetDefaultDeviceCPU()
|
||||
|
||||
// 4x4 patch grid, single image
|
||||
imgH, imgW := int32(4), int32(4)
|
||||
txtLen := int32(5)
|
||||
axesDims := []int32{16, 56, 56}
|
||||
|
||||
cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
|
||||
mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
|
||||
|
||||
// Check shapes
|
||||
imgShape := cache.ImgFreqs.Shape()
|
||||
if imgShape[0] != 16 { // 4*4 patches
|
||||
t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
|
||||
}
|
||||
|
||||
// For single image (frame=0), all temporal values should be cos=1, sin=0
|
||||
imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
|
||||
mlx.Eval(imgFreqsCPU)
|
||||
imgData := imgFreqsCPU.Data()
|
||||
|
||||
// Check first 16 values of patch 0 (temporal cos/sin pairs)
|
||||
for i := 0; i < 16; i += 2 {
|
||||
cosVal := imgData[i]
|
||||
sinVal := imgData[i+1]
|
||||
if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
|
||||
}
|
||||
if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
|
||||
t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
|
||||
}
|
||||
}
|
||||
|
||||
cache.ImgFreqs.Free()
|
||||
cache.TxtFreqs.Free()
|
||||
}
|
||||
|
||||
// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
|
||||
func TestScaleRopePositions(t *testing.T) {
|
||||
// For a 4x4 grid with scale_rope=True:
|
||||
// hHalf = 2, wHalf = 2
|
||||
// hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
// wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
|
||||
//
|
||||
// Height positions:
|
||||
// y=0: -(4-2) + 0 = -2
|
||||
// y=1: -(4-2) + 1 = -1
|
||||
// y=2: 2 - 2 = 0
|
||||
// y=3: 3 - 2 = 1
|
||||
//
|
||||
// Same for width
|
||||
|
||||
pH, pW := int32(4), int32(4)
|
||||
hHalf := pH / 2
|
||||
wHalf := pW / 2
|
||||
hNegCount := pH - hHalf
|
||||
wNegCount := pW - wHalf
|
||||
|
||||
expectedH := []int32{-2, -1, 0, 1}
|
||||
expectedW := []int32{-2, -1, 0, 1}
|
||||
|
||||
for y := int32(0); y < pH; y++ {
|
||||
var hPos int32
|
||||
if y < hNegCount {
|
||||
hPos = -(pH - hHalf) + y
|
||||
} else {
|
||||
hPos = y - hNegCount
|
||||
}
|
||||
if hPos != expectedH[y] {
|
||||
t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
|
||||
}
|
||||
}
|
||||
|
||||
for x := int32(0); x < pW; x++ {
|
||||
var wPos int32
|
||||
if x < wNegCount {
|
||||
wPos = -(pW - wHalf) + x
|
||||
} else {
|
||||
wPos = x - wNegCount
|
||||
}
|
||||
if wPos != expectedW[x] {
|
||||
t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoPEHeadDimensions verifies the head dimension breakdown
|
||||
func TestRoPEHeadDimensions(t *testing.T) {
|
||||
// axes_dims_rope = [16, 56, 56]
|
||||
// Each dimension uses half the values for frequencies
|
||||
// So we get: 8 + 28 + 28 = 64 frequency values
|
||||
// Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
|
||||
|
||||
axesDims := []int32{16, 56, 56}
|
||||
expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
|
||||
expectedHeadDim := expectedFreqs * 2
|
||||
|
||||
if expectedFreqs != 64 {
|
||||
t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
|
||||
}
|
||||
if expectedHeadDim != 128 {
|
||||
t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
|
||||
}
|
||||
|
||||
// This should match the transformer's attention head dimension
|
||||
// hidden_size = 3072, num_heads = 24
|
||||
// head_dim = 3072 / 24 = 128
|
||||
}
|
||||
|
||||
642
x/imagegen/models/qwen_image_edit/vae.go
Normal file
642
x/imagegen/models/qwen_image_edit/vae.go
Normal file
@@ -0,0 +1,642 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds Qwen-Image VAE configuration
|
||||
type VAEConfig struct {
|
||||
ZDim int32 `json:"z_dim"` // 16
|
||||
BaseDim int32 `json:"base_dim"` // 96
|
||||
DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
|
||||
NumResBlocks int32 `json:"num_res_blocks"` // 2
|
||||
LatentsMean []float32 `json:"latents_mean"` // 16 values
|
||||
LatentsStd []float32 `json:"latents_std"` // 16 values
|
||||
TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
|
||||
}
|
||||
|
||||
// defaultVAEConfig returns config for Qwen-Image VAE
|
||||
func defaultVAEConfig() *VAEConfig {
|
||||
return &VAEConfig{
|
||||
ZDim: 16,
|
||||
BaseDim: 96,
|
||||
DimMult: []int32{1, 2, 4, 4},
|
||||
NumResBlocks: 2,
|
||||
LatentsMean: []float32{
|
||||
-0.7571, -0.7089, -0.9113, 0.1075,
|
||||
-0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632,
|
||||
-0.1922, -0.9497, 0.2503, -0.2921,
|
||||
},
|
||||
LatentsStd: []float32{
|
||||
2.8184, 1.4541, 2.3275, 2.6558,
|
||||
1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579,
|
||||
1.6382, 1.1253, 2.8251, 1.916,
|
||||
},
|
||||
TemperalDownsample: []bool{false, true, true},
|
||||
}
|
||||
}
|
||||
|
||||
// VAE is the full VAE with encoder and decoder
|
||||
type VAE struct {
|
||||
Config *VAEConfig
|
||||
Encoder *VAEEncoder
|
||||
Decoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the VAE from a directory
|
||||
func (m *VAE) Load(path string) error {
|
||||
fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
|
||||
|
||||
cfg := defaultVAEConfig()
|
||||
m.Config = cfg
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load weights as f32 for quality (matches Python default behavior)
|
||||
// VAE decoder precision is critical for final image quality
|
||||
fmt.Print(" Loading weights as f32... ")
|
||||
if err := weights.Load(mlx.DtypeFloat32); err != nil {
|
||||
return fmt.Errorf("failed to load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
// Load encoder
|
||||
fmt.Print(" Loading encoder... ")
|
||||
m.Encoder = &VAEEncoder{}
|
||||
if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load decoder
|
||||
fmt.Print(" Loading decoder... ")
|
||||
m.Decoder = &VAEDecoder{}
|
||||
if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("decoder: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor in [-1, 1] range
|
||||
// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
|
||||
func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
|
||||
return m.Encoder.Encode(x)
|
||||
}
|
||||
|
||||
// Decode decodes latents to image
|
||||
// z: [B, C, T, H, W] latents (denormalized)
|
||||
// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
|
||||
func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
|
||||
return m.Decoder.Decode(z)
|
||||
}
|
||||
|
||||
// Normalize applies latent normalization
|
||||
// Input z should be f32 (from VAE encoder), output is f32 for transformer
|
||||
func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
// Mean/std are f32, will match z dtype through broadcasting
|
||||
return mlx.Div(mlx.Sub(z, mean), std)
|
||||
}
|
||||
|
||||
// Denormalize reverses latent normalization
|
||||
// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
|
||||
func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
|
||||
shape := z.Shape()
|
||||
C := shape[1]
|
||||
|
||||
// Convert latents to f32 for VAE decoder quality
|
||||
z = mlx.AsType(z, mlx.DtypeFloat32)
|
||||
|
||||
mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
|
||||
std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
|
||||
|
||||
return mlx.Add(mlx.Mul(z, std), mean)
|
||||
}
|
||||
|
||||
// VAEEncoder is the encoder part of the VAE
|
||||
// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
|
||||
// - Blocks 0,1: ResBlocks (base_dim)
|
||||
// - Block 2: Downsample
|
||||
// - Blocks 3,4: ResBlocks (base_dim*2)
|
||||
// - Block 5: Downsample + temporal
|
||||
// - Blocks 6,7: ResBlocks (base_dim*4)
|
||||
// - Block 8: Downsample + temporal
|
||||
// - Blocks 9,10: ResBlocks (base_dim*4)
|
||||
type VAEEncoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
ConvIn *CausalConv3d
|
||||
Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
|
||||
MidBlock *MidBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
QuantConv *CausalConv3d
|
||||
}
|
||||
|
||||
// EncoderBlock is either a ResBlock or a Downsample
|
||||
type EncoderBlock interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
IsDownsample() bool
|
||||
}
|
||||
|
||||
// EncoderResBlock wraps ResBlock
|
||||
type EncoderResBlock struct {
|
||||
*ResBlock
|
||||
}
|
||||
|
||||
func (b *EncoderResBlock) IsDownsample() bool { return false }
|
||||
|
||||
// EncoderDownsample is a downsample layer
|
||||
type EncoderDownsample struct {
|
||||
Resample *CausalConv3d
|
||||
TimeConv *CausalConv3d // Optional temporal downsample
|
||||
}
|
||||
|
||||
func (d *EncoderDownsample) IsDownsample() bool { return true }
|
||||
|
||||
func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Spatial downsample with stride 2
|
||||
// WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
|
||||
x = d.forwardSpatialDownsample(x)
|
||||
|
||||
// NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
|
||||
// with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
|
||||
// The Python forward checks: if feat_cache is not None ... then use time_conv
|
||||
// Since we don't support streaming, we skip time_conv entirely.
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
|
||||
func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
|
||||
xShape := x.Shape()
|
||||
B := xShape[0]
|
||||
T := xShape[1]
|
||||
H := xShape[2]
|
||||
W := xShape[3]
|
||||
C := xShape[4]
|
||||
|
||||
wShape := d.Resample.Weight.Shape()
|
||||
outC := wShape[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
|
||||
// ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
|
||||
|
||||
// Apply 2D conv with stride 2
|
||||
weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
|
||||
if d.Resample.Bias != nil {
|
||||
bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
// Output dims after stride 2: (H+1)/2, (W+1)/2
|
||||
outH := (H + 1) / 2
|
||||
outW := (W + 1) / 2
|
||||
|
||||
// Reshape back to [B, T, H', W', C]
|
||||
x = mlx.Reshape(x, B, T, outH, outW, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// loadFromWeights loads the encoder from pre-loaded weights
|
||||
func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
e.Config = cfg
|
||||
|
||||
// Conv in
|
||||
convIn, err := newCausalConv3d(weights, "encoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvIn = convIn
|
||||
|
||||
// Encoder uses flat block structure:
|
||||
// dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
|
||||
// Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
|
||||
// That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
|
||||
e.Blocks = make([]EncoderBlock, 0, 11)
|
||||
|
||||
// Track dimensions
|
||||
dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
|
||||
blockIdx := 0
|
||||
|
||||
for stage := 0; stage < len(cfg.DimMult); stage++ {
|
||||
inDim := cfg.BaseDim
|
||||
if stage > 0 {
|
||||
inDim = dims[stage-1]
|
||||
}
|
||||
outDim := dims[stage]
|
||||
|
||||
// ResBlocks for this stage (num_res_blocks per stage)
|
||||
for r := int32(0); r < cfg.NumResBlocks; r++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
currentInDim := inDim
|
||||
if r > 0 {
|
||||
currentInDim = outDim
|
||||
}
|
||||
block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, block)
|
||||
blockIdx++
|
||||
}
|
||||
|
||||
// Downsample after each stage except the last
|
||||
if stage < len(cfg.DimMult)-1 {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
|
||||
down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
|
||||
}
|
||||
e.Blocks = append(e.Blocks, down)
|
||||
blockIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.MidBlock = midBlock
|
||||
|
||||
// Norm out
|
||||
normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.NormOut = normOut
|
||||
|
||||
// Conv out
|
||||
convOut, err := newCausalConv3d(weights, "encoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.ConvOut = convOut
|
||||
|
||||
// Quant conv
|
||||
quantConv, err := newCausalConv3d(weights, "quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.QuantConv = quantConv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
|
||||
func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
|
||||
block, err := newResBlock(weights, prefix, inDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncoderResBlock{block}, nil
|
||||
}
|
||||
|
||||
// newEncoderDownsample creates a downsample layer for the encoder
|
||||
func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
|
||||
resample, err := newCausalConv3d(weights, prefix+".resample.1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var timeConv *CausalConv3d
|
||||
if temporal {
|
||||
timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
|
||||
}
|
||||
|
||||
return &EncoderDownsample{
|
||||
Resample: resample,
|
||||
TimeConv: timeConv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encode encodes an image to latents
|
||||
// x: [B, C, T, H, W] image tensor (channels-first)
|
||||
// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
|
||||
func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
|
||||
// Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(x)
|
||||
|
||||
// Conv in
|
||||
x = e.ConvIn.Forward(x)
|
||||
|
||||
// Encoder blocks (mix of ResBlocks and Downsamplers)
|
||||
for _, block := range e.Blocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = e.MidBlock.Forward(x)
|
||||
|
||||
// Norm + silu
|
||||
{
|
||||
prev := x
|
||||
x = e.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Conv out
|
||||
{
|
||||
prev := x
|
||||
x = e.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Quant conv
|
||||
{
|
||||
prev := x
|
||||
x = e.QuantConv.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Get mode from distribution (first half of channels = mean)
|
||||
// Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
|
||||
shape := x.Shape()
|
||||
latentC := shape[4] / 2
|
||||
x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
|
||||
|
||||
// Convert back to channels-first [N, C, T, H, W]
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the decoder part of the VAE
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
PostQuantConv *CausalConv3d
|
||||
ConvIn *CausalConv3d
|
||||
MidBlock *MidBlock
|
||||
UpBlocks []*UpBlock
|
||||
NormOut *RMSNorm3D
|
||||
ConvOut *CausalConv3d
|
||||
}
|
||||
|
||||
// loadFromWeights loads the decoder from pre-loaded weights
|
||||
func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
|
||||
d.Config = cfg
|
||||
|
||||
postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PostQuantConv = postQuantConv
|
||||
|
||||
convIn, err := newCausalConv3d(weights, "decoder.conv_in")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvIn = convIn
|
||||
|
||||
// Mid block
|
||||
midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
|
||||
midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.MidBlock = midBlock
|
||||
|
||||
// Up blocks (reversed dim_mult)
|
||||
numUpBlocks := len(cfg.DimMult)
|
||||
d.UpBlocks = make([]*UpBlock, numUpBlocks)
|
||||
|
||||
dimsMult := make([]int32, numUpBlocks+1)
|
||||
dimsMult[0] = cfg.DimMult[numUpBlocks-1]
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
|
||||
}
|
||||
|
||||
temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
|
||||
for i := range cfg.TemperalDownsample {
|
||||
temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
|
||||
}
|
||||
|
||||
for i := 0; i < numUpBlocks; i++ {
|
||||
inDim := cfg.BaseDim * dimsMult[i]
|
||||
outDim := cfg.BaseDim * dimsMult[i+1]
|
||||
|
||||
if i > 0 {
|
||||
inDim = inDim / 2
|
||||
}
|
||||
|
||||
upsampleMode := ""
|
||||
if i < numUpBlocks-1 {
|
||||
if temporalUpsample[i] {
|
||||
upsampleMode = "upsample3d"
|
||||
} else {
|
||||
upsampleMode = "upsample2d"
|
||||
}
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.UpBlocks[i] = upBlock
|
||||
}
|
||||
|
||||
normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NormOut = normOut
|
||||
|
||||
convOut, err := newCausalConv3d(weights, "decoder.conv_out")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ConvOut = convOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode converts latents to image
|
||||
// z: [B, C, T, H, W] denormalized latents
|
||||
func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
|
||||
var x *mlx.Array
|
||||
|
||||
// Convert from channels-first to channels-last
|
||||
{
|
||||
z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
|
||||
mlx.Eval(z)
|
||||
}
|
||||
|
||||
// PostQuantConv
|
||||
x = d.PostQuantConv.Forward(z)
|
||||
z.Free()
|
||||
|
||||
// ConvIn
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvIn.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Mid block
|
||||
x = d.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks
|
||||
for _, upBlock := range d.UpBlocks {
|
||||
x = upBlock.Forward(x)
|
||||
}
|
||||
|
||||
// NormOut + silu
|
||||
{
|
||||
prev := x
|
||||
x = d.NormOut.Forward(x)
|
||||
x = silu3D(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// ConvOut
|
||||
{
|
||||
prev := x
|
||||
x = d.ConvOut.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
// Post-processing: clamp and convert back to channels-first
|
||||
{
|
||||
prev := x
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// DownBlock handles downsampling in encoder
|
||||
type DownBlock struct {
|
||||
ResBlocks []*ResBlock
|
||||
Downsampler *Downsample
|
||||
}
|
||||
|
||||
// newDownBlock creates a down block
|
||||
func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
|
||||
resBlocks := make([]*ResBlock, numBlocks+1)
|
||||
|
||||
currentDim := inDim
|
||||
for i := int32(0); i <= numBlocks; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
block, err := newResBlock(weights, resPrefix, currentDim, outDim)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBlocks[i] = block
|
||||
currentDim = outDim
|
||||
}
|
||||
|
||||
var downsampler *Downsample
|
||||
if downsampleMode != "" {
|
||||
downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
|
||||
}
|
||||
|
||||
return &DownBlock{
|
||||
ResBlocks: resBlocks,
|
||||
Downsampler: downsampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies down block
|
||||
func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, block := range d.ResBlocks {
|
||||
prev := x
|
||||
x = block.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if d.Downsampler != nil {
|
||||
prev := x
|
||||
x = d.Downsampler.Forward(x)
|
||||
prev.Free()
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Downsample handles spatial downsampling
|
||||
type Downsample struct {
|
||||
Conv *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Mode string
|
||||
}
|
||||
|
||||
// newDownsample creates a downsampler
|
||||
func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
|
||||
conv, _ := weights.Get(prefix + ".resample.1.weight")
|
||||
bias, _ := weights.Get(prefix + ".resample.1.bias")
|
||||
return &Downsample{
|
||||
Conv: conv,
|
||||
Bias: bias,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies downsampling to channels-last input [B, T, H, W, C]
|
||||
func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
T := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
C := shape[4]
|
||||
outC := d.Conv.Shape()[0]
|
||||
|
||||
// Reshape to [B*T, H, W, C] for 2D conv
|
||||
x = mlx.Reshape(x, B*T, H, W, C)
|
||||
|
||||
// Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
|
||||
// For 3x3 stride 2: pad 1 on all sides
|
||||
x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
|
||||
|
||||
// Conv with stride 2 using manual strided patching
|
||||
weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
|
||||
x = conv2DStrided(x, weight, 2)
|
||||
if d.Bias != nil {
|
||||
bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
|
||||
x = mlx.Add(x, bias)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, T, H/2, W/2, outC)
|
||||
mlx.Eval(x)
|
||||
|
||||
return x
|
||||
}
|
||||
148
x/imagegen/models/zimage/scheduler.go
Normal file
148
x/imagegen/models/zimage/scheduler.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
Shift float32 `json:"shift"` // 3.0
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
|
||||
}
|
||||
|
||||
// DefaultFlowMatchSchedulerConfig returns default config
|
||||
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
Shift: 3.0,
|
||||
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
|
||||
// This is used in Z-Image-Turbo for fast sampling
|
||||
type FlowMatchEulerScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Discretized timesteps
|
||||
Sigmas []float32 // Noise levels at each timestep
|
||||
NumSteps int // Number of inference steps
|
||||
}
|
||||
|
||||
// NewFlowMatchEulerScheduler creates a new scheduler
|
||||
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
|
||||
return &FlowMatchEulerScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
|
||||
s.SetTimestepsWithMu(numSteps, 0)
|
||||
}
|
||||
|
||||
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
|
||||
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
|
||||
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
|
||||
for i := 0; i <= numSteps; i++ {
|
||||
t := 1.0 - float32(i)/float32(numSteps)
|
||||
|
||||
// Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShifting && mu != 0 {
|
||||
t = s.timeShift(mu, t)
|
||||
}
|
||||
|
||||
s.Timesteps[i] = t
|
||||
s.Sigmas[i] = t
|
||||
}
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift (match Python)
|
||||
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
// exp(mu) / (exp(mu) + (1/t - 1))
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
// modelOutput: predicted velocity/noise from the model
|
||||
// timestepIdx: current timestep index
|
||||
// sample: current noisy sample
|
||||
// Returns: denoised sample for next step
|
||||
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Get current and next sigma
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
// where v_t is the velocity predicted by the model
|
||||
dt := sigmaNext - sigma // This is negative (going from noise to clean)
|
||||
|
||||
// x_next = x + dt * velocity
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// ScaleSample scales the sample for model input (identity for flow matching)
|
||||
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// Flow matching doesn't need scaling
|
||||
return sample
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps (implements Scheduler interface)
|
||||
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep
|
||||
// Used for img2img or inpainting
|
||||
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-t) * x_0 + t * noise
|
||||
t := s.Timesteps[timestepIdx]
|
||||
oneMinusT := 1.0 - t
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
|
||||
scaledNoise := mlx.MulScalar(noise, t)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return RandomNormal(shape, seed)
|
||||
}
|
||||
|
||||
// RandomNormal creates a random normal tensor using MLX
|
||||
func RandomNormal(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormal(shape, uint64(seed))
|
||||
}
|
||||
|
||||
// GetLatentShape returns the latent shape for a given image size
|
||||
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
|
||||
// Latent is 8x smaller than image (VAE downscale)
|
||||
latentH := height / 8
|
||||
latentW := width / 8
|
||||
|
||||
return []int32{batchSize, latentChannels, latentH, latentW}
|
||||
}
|
||||
296
x/imagegen/models/zimage/text_encoder.go
Normal file
296
x/imagegen/models/zimage/text_encoder.go
Normal file
@@ -0,0 +1,296 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen3 text encoder configuration
|
||||
type Qwen3Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// loadQwen3Config loads text encoder config from a JSON file
|
||||
func loadQwen3Config(path string) (*Qwen3Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj *nn.Linear `weight:"q_proj"`
|
||||
KProj *nn.Linear `weight:"k_proj"`
|
||||
VProj *nn.Linear `weight:"v_proj"`
|
||||
OProj *nn.Linear `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking
|
||||
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj *nn.Linear `weight:"gate_proj"`
|
||||
UpProj *nn.Linear `weight:"up_proj"`
|
||||
DownProj *nn.Linear `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Qwen3Block represents a single Qwen3 transformer block
|
||||
type Qwen3Block struct {
|
||||
Attention *Qwen3Attention `weight:"self_attn"`
|
||||
MLP *Qwen3MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
|
||||
type Qwen3TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Qwen3Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from a directory
|
||||
func (m *Qwen3TextEncoder) Load(path string) error {
|
||||
fmt.Println("Loading Qwen3 text encoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = cfg
|
||||
|
||||
// Pre-allocate layers slice
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in Qwen3 chat format
|
||||
func ApplyChatTemplate(prompt string) string {
|
||||
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder
|
||||
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
embeddings := te.Forward(tokensArr)
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
692
x/imagegen/models/zimage/transformer.go
Normal file
692
x/imagegen/models/zimage/transformer.go
Normal file
@@ -0,0 +1,692 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Z-Image transformer configuration
|
||||
type TransformerConfig struct {
|
||||
Dim int32 `json:"dim"`
|
||||
NHeads int32 `json:"n_heads"`
|
||||
NKVHeads int32 `json:"n_kv_heads"`
|
||||
NLayers int32 `json:"n_layers"`
|
||||
NRefinerLayers int32 `json:"n_refiner_layers"`
|
||||
InChannels int32 `json:"in_channels"`
|
||||
PatchSize int32 `json:"-"` // Computed from AllPatchSize
|
||||
CapFeatDim int32 `json:"cap_feat_dim"`
|
||||
NormEps float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
TScale float32 `json:"t_scale"`
|
||||
QKNorm bool `json:"qk_norm"`
|
||||
AxesDims []int32 `json:"axes_dims"`
|
||||
AxesLens []int32 `json:"axes_lens"`
|
||||
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates sinusoidal timestep embeddings
|
||||
// Output dimension is 256 (fixed), used for AdaLN modulation
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 *nn.Linear `weight:"mlp.0"`
|
||||
Linear2 *nn.Linear `weight:"mlp.2"`
|
||||
FreqEmbedSize int32 // 256 (computed)
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings -> [B, 256]
|
||||
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
|
||||
// t: [B] timesteps
|
||||
|
||||
// Create sinusoidal embedding
|
||||
half := te.FreqEmbedSize / 2
|
||||
|
||||
// freqs = exp(-log(10000) * arange(half) / half)
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// t[:, None] * freqs[None, :] -> [B, half]
|
||||
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
|
||||
// embedding = [cos(args), sin(args)] -> [B, 256]
|
||||
cosArgs := mlx.Cos(args)
|
||||
sinArgs := mlx.Sin(args)
|
||||
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
|
||||
|
||||
// MLP: linear1 -> silu -> linear2
|
||||
h := te.Linear1.Forward(embedding)
|
||||
h = mlx.SiLU(h)
|
||||
h = te.Linear2.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// XEmbedder embeds image patches to model dimension
|
||||
type XEmbedder struct {
|
||||
Linear *nn.Linear `weight:"2-1"`
|
||||
}
|
||||
|
||||
// Forward embeds patchified image latents
|
||||
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, L, in_channels * 4] -> [B, L, dim]
|
||||
return xe.Linear.Forward(x)
|
||||
}
|
||||
|
||||
// CapEmbedder projects caption features to model dimension
|
||||
type CapEmbedder struct {
|
||||
Norm *nn.RMSNorm `weight:"0"`
|
||||
Linear *nn.Linear `weight:"1"`
|
||||
PadToken *mlx.Array // loaded separately at root level
|
||||
}
|
||||
|
||||
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
|
||||
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
|
||||
// RMSNorm on last axis (uses 1e-6)
|
||||
h := ce.Norm.Forward(capFeats, 1e-6)
|
||||
// Linear projection
|
||||
return ce.Linear.Forward(h)
|
||||
}
|
||||
|
||||
// FeedForward implements SwiGLU FFN
|
||||
type FeedForward struct {
|
||||
W1 *nn.Linear `weight:"w1"` // gate projection
|
||||
W2 *nn.Linear `weight:"w2"` // down projection
|
||||
W3 *nn.Linear `weight:"w3"` // up projection
|
||||
OutDim int32 // computed from W2
|
||||
}
|
||||
|
||||
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Reshape for matmul
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
gate := ff.W1.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := ff.W3.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
out := ff.W2.Forward(h)
|
||||
|
||||
return mlx.Reshape(out, B, L, ff.OutDim)
|
||||
}
|
||||
|
||||
// Attention implements multi-head attention with QK norm
|
||||
type Attention struct {
|
||||
ToQ *nn.Linear `weight:"to_q"`
|
||||
ToK *nn.Linear `weight:"to_k"`
|
||||
ToV *nn.Linear `weight:"to_v"`
|
||||
ToOut *nn.Linear `weight:"to_out.0"`
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Dim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward computes attention
|
||||
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Project Q, K, V
|
||||
xFlat := mlx.Reshape(x, B*L, D)
|
||||
q := attn.ToQ.Forward(xFlat)
|
||||
k := attn.ToK.Forward(xFlat)
|
||||
v := attn.ToV.Forward(xFlat)
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
|
||||
|
||||
// QK norm
|
||||
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
|
||||
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
|
||||
|
||||
// Apply RoPE if provided
|
||||
if cos != nil && sin != nil {
|
||||
q = applyRoPE3D(q, cos, sin)
|
||||
k = applyRoPE3D(k, cos, sin)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
|
||||
|
||||
// Transpose back and reshape
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B*L, attn.Dim)
|
||||
out = attn.ToOut.Forward(out)
|
||||
|
||||
return mlx.Reshape(out, B, L, attn.Dim)
|
||||
}
|
||||
|
||||
// applyRoPE3D applies 3-axis rotary position embeddings
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// cos, sin: [B, L, 1, head_dim/2]
|
||||
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
half := headDim / 2
|
||||
|
||||
// Create even/odd index arrays
|
||||
evenIdx := make([]int32, half)
|
||||
oddIdx := make([]int32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
evenIdx[i] = i * 2
|
||||
oddIdx[i] = i*2 + 1
|
||||
}
|
||||
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
|
||||
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
|
||||
|
||||
// Extract x1 (even indices) and x2 (odd indices) along last axis
|
||||
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
|
||||
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
|
||||
|
||||
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
|
||||
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
|
||||
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
|
||||
|
||||
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
|
||||
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
|
||||
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
|
||||
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
|
||||
return mlx.Reshape(stacked, B, L, nheads, headDim)
|
||||
}
|
||||
|
||||
// TransformerBlock is a single transformer block with optional AdaLN modulation
|
||||
type TransformerBlock struct {
|
||||
Attention *Attention `weight:"attention"`
|
||||
FeedForward *FeedForward `weight:"feed_forward"`
|
||||
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
|
||||
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
|
||||
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
|
||||
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
|
||||
// Computed fields
|
||||
HasModulation bool
|
||||
Dim int32
|
||||
}
|
||||
|
||||
// Forward applies the transformer block
|
||||
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
|
||||
if tb.AdaLN != nil && adaln != nil {
|
||||
// Compute modulation: [B, 256] -> [B, 4*dim]
|
||||
chunks := tb.AdaLN.Forward(adaln)
|
||||
|
||||
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
|
||||
chunkShape := chunks.Shape()
|
||||
chunkDim := chunkShape[1] / 4
|
||||
|
||||
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
|
||||
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
|
||||
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
|
||||
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
|
||||
|
||||
// Expand for broadcasting: [B, 1, dim]
|
||||
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
|
||||
gateMSA = mlx.ExpandDims(gateMSA, 1)
|
||||
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
|
||||
gateMLP = mlx.ExpandDims(gateMLP, 1)
|
||||
|
||||
// Attention with modulation
|
||||
normX := tb.AttentionNorm1.Forward(x, eps)
|
||||
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
|
||||
attnOut := tb.Attention.Forward(normX, cos, sin)
|
||||
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
|
||||
|
||||
// FFN with modulation
|
||||
normFFN := tb.FFNNorm1.Forward(x, eps)
|
||||
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
|
||||
ffnOut := tb.FeedForward.Forward(normFFN)
|
||||
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
|
||||
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
|
||||
} else {
|
||||
// No modulation (context refiner)
|
||||
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
|
||||
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
|
||||
|
||||
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
|
||||
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// FinalLayer outputs the denoised patches
|
||||
type FinalLayer struct {
|
||||
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
|
||||
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
|
||||
OutDim int32 // computed from Output
|
||||
}
|
||||
|
||||
// Forward computes final output
|
||||
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
|
||||
// c: [B, 256] -> scale: [B, dim]
|
||||
scale := mlx.SiLU(c)
|
||||
scale = fl.AdaLN.Forward(scale)
|
||||
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
|
||||
|
||||
// LayerNorm (affine=False) then scale
|
||||
x = layerNormNoAffine(x, 1e-6)
|
||||
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
||||
|
||||
// Output projection
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
x = mlx.Reshape(x, B*L, D)
|
||||
x = fl.Output.Forward(x)
|
||||
|
||||
return mlx.Reshape(x, B, L, fl.OutDim)
|
||||
}
|
||||
|
||||
// layerNormNoAffine applies layer norm without learnable parameters
|
||||
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
|
||||
ndim := x.Ndim()
|
||||
lastAxis := ndim - 1
|
||||
|
||||
mean := mlx.Mean(x, lastAxis, true)
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
|
||||
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
|
||||
}
|
||||
|
||||
// Transformer is the full Z-Image DiT model
|
||||
type Transformer struct {
|
||||
TEmbed *TimestepEmbedder `weight:"t_embedder"`
|
||||
XEmbed *XEmbedder `weight:"all_x_embedder"`
|
||||
CapEmbed *CapEmbedder `weight:"cap_embedder"`
|
||||
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
|
||||
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
|
||||
Layers []*TransformerBlock `weight:"layers"`
|
||||
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
|
||||
XPadToken *mlx.Array `weight:"x_pad_token"`
|
||||
CapPadToken *mlx.Array `weight:"cap_pad_token"`
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Z-Image transformer from a directory
|
||||
func (m *Transformer) Load(path string) error {
|
||||
fmt.Println("Loading Z-Image transformer...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = cfg
|
||||
|
||||
// Pre-allocate slices for loader
|
||||
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
|
||||
m.Layers = make([]*TransformerBlock, cfg.NLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(" Loading weights as bf16... ")
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
fmt.Print(" Loading weights via struct tags... ")
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Initialize computed fields
|
||||
m.TEmbed.FreqEmbedSize = 256
|
||||
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
|
||||
m.CapEmbed.Norm.Eps = 1e-6
|
||||
|
||||
for _, block := range m.NoiseRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.ContextRefiners {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
for _, block := range m.Layers {
|
||||
initTransformerBlock(block, cfg)
|
||||
}
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTransformerConfig loads transformer config from a JSON file
|
||||
func loadTransformerConfig(path string) (*TransformerConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg TransformerConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
// Extract PatchSize from array
|
||||
if len(cfg.AllPatchSize) > 0 {
|
||||
cfg.PatchSize = cfg.AllPatchSize[0]
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// initTransformerBlock sets computed fields on a transformer block
|
||||
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
|
||||
block.Dim = cfg.Dim
|
||||
block.HasModulation = block.AdaLN != nil
|
||||
|
||||
// Init attention computed fields
|
||||
attn := block.Attention
|
||||
attn.NHeads = cfg.NHeads
|
||||
attn.HeadDim = cfg.Dim / cfg.NHeads
|
||||
attn.Dim = cfg.Dim
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
|
||||
|
||||
// Init feedforward OutDim
|
||||
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
|
||||
|
||||
// Set eps on all RMSNorm layers
|
||||
block.AttentionNorm1.Eps = cfg.NormEps
|
||||
block.AttentionNorm2.Eps = cfg.NormEps
|
||||
block.FFNNorm1.Eps = cfg.NormEps
|
||||
block.FFNNorm2.Eps = cfg.NormEps
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE values
|
||||
type RoPECache struct {
|
||||
ImgCos *mlx.Array
|
||||
ImgSin *mlx.Array
|
||||
CapCos *mlx.Array
|
||||
CapSin *mlx.Array
|
||||
UnifiedCos *mlx.Array
|
||||
UnifiedSin *mlx.Array
|
||||
ImgLen int32
|
||||
CapLen int32
|
||||
}
|
||||
|
||||
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
|
||||
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
|
||||
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
|
||||
imgLen := hTok * wTok
|
||||
|
||||
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
|
||||
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
|
||||
imgPos = mlx.ToBFloat16(imgPos)
|
||||
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
|
||||
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
|
||||
capPos = mlx.ToBFloat16(capPos)
|
||||
|
||||
// Compute RoPE from UNIFIED positions
|
||||
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
|
||||
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
|
||||
|
||||
// Slice RoPE for image and caption parts
|
||||
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
|
||||
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
|
||||
|
||||
return &RoPECache{
|
||||
ImgCos: imgCos,
|
||||
ImgSin: imgSin,
|
||||
CapCos: capCos,
|
||||
CapSin: capSin,
|
||||
UnifiedCos: unifiedCos,
|
||||
UnifiedSin: unifiedSin,
|
||||
ImgLen: imgLen,
|
||||
CapLen: capLen,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs the Z-Image transformer with precomputed RoPE
|
||||
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Embed caption features -> [B, L_cap, dim]
|
||||
capEmb := m.CapEmbed.Forward(capFeats)
|
||||
|
||||
eps := m.NormEps
|
||||
|
||||
// Noise refiner: refine image patches with modulation
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Context refiner: refine caption (no modulation)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Main transformer layers use full unified RoPE
|
||||
for _, layer := range m.Layers {
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// ForwardWithCache runs the transformer with layer caching for faster inference.
|
||||
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
|
||||
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
|
||||
func (m *Transformer) ForwardWithCache(
|
||||
x *mlx.Array,
|
||||
t *mlx.Array,
|
||||
capFeats *mlx.Array,
|
||||
rope *RoPECache,
|
||||
stepCache *cache.StepCache,
|
||||
step int,
|
||||
cacheInterval int,
|
||||
) *mlx.Array {
|
||||
imgLen := rope.ImgLen
|
||||
cacheLayers := stepCache.NumLayers()
|
||||
eps := m.NormEps
|
||||
|
||||
// Timestep embedding -> [B, 256]
|
||||
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
|
||||
|
||||
// Embed image patches -> [B, L_img, dim]
|
||||
x = m.XEmbed.Forward(x)
|
||||
|
||||
// Context refiners: compute once on step 0, reuse forever
|
||||
// (caption embedding doesn't depend on timestep or latents)
|
||||
var capEmb *mlx.Array
|
||||
if stepCache.GetConstant() != nil {
|
||||
capEmb = stepCache.GetConstant()
|
||||
} else {
|
||||
capEmb = m.CapEmbed.Forward(capFeats)
|
||||
for _, refiner := range m.ContextRefiners {
|
||||
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
|
||||
}
|
||||
stepCache.SetConstant(capEmb)
|
||||
}
|
||||
|
||||
// Noise refiners: always compute (depend on x which changes each step)
|
||||
for _, refiner := range m.NoiseRefiners {
|
||||
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
|
||||
}
|
||||
|
||||
// Concatenate image and caption for joint attention
|
||||
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
|
||||
|
||||
// Determine if this is a cache refresh step
|
||||
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
|
||||
|
||||
// Main transformer layers with caching
|
||||
for i, layer := range m.Layers {
|
||||
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
|
||||
// Use cached output for shallow layers
|
||||
unified = stepCache.Get(i)
|
||||
} else {
|
||||
// Compute layer
|
||||
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
|
||||
// Cache shallow layer outputs on refresh steps
|
||||
if i < cacheLayers && refreshCache {
|
||||
stepCache.Set(i, unified)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract image tokens only
|
||||
unifiedShape := unified.Shape()
|
||||
B := unifiedShape[0]
|
||||
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
|
||||
|
||||
// Final layer
|
||||
return m.FinalLayer.Forward(imgOut, temb)
|
||||
}
|
||||
|
||||
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
|
||||
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
|
||||
// Create meshgrid and stack
|
||||
total := d0 * d1 * d2
|
||||
coords := make([]float32, total*3)
|
||||
|
||||
idx := 0
|
||||
for i := int32(0); i < d0; i++ {
|
||||
for j := int32(0); j < d1; j++ {
|
||||
for k := int32(0); k < d2; k++ {
|
||||
coords[idx*3+0] = float32(s0 + i)
|
||||
coords[idx*3+1] = float32(s1 + j)
|
||||
coords[idx*3+2] = float32(s2 + k)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArray(coords, []int32{1, total, 3})
|
||||
}
|
||||
|
||||
// prepareRoPE3D computes cos/sin for 3-axis RoPE
|
||||
// positions: [B, L, 3] with (h, w, t) coordinates
|
||||
// axesDims: [32, 48, 48] - dimensions for each axis
|
||||
// Returns: cos, sin each [B, L, 1, head_dim/2]
|
||||
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
|
||||
// Compute frequencies for each axis
|
||||
// dims = [32, 48, 48], so halves = [16, 24, 24]
|
||||
ropeTheta := float32(256.0)
|
||||
|
||||
freqs := make([]*mlx.Array, 3)
|
||||
for axis := 0; axis < 3; axis++ {
|
||||
half := axesDims[axis] / 2
|
||||
f := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
|
||||
}
|
||||
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
|
||||
}
|
||||
|
||||
// Extract position coordinates
|
||||
shape := positions.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// positions[:, :, 0] -> h positions
|
||||
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
|
||||
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
|
||||
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
|
||||
|
||||
// Compute args: pos * freqs for each axis
|
||||
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
|
||||
posW = mlx.ExpandDims(posW, 3)
|
||||
posT = mlx.ExpandDims(posT, 3)
|
||||
|
||||
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
|
||||
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
|
||||
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
|
||||
|
||||
// Concatenate: [B, L, 1, 16+24+24=64]
|
||||
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
|
||||
|
||||
// Compute cos and sin
|
||||
return mlx.Cos(args), mlx.Sin(args)
|
||||
}
|
||||
|
||||
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
|
||||
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize // H_tok
|
||||
pW := W / patchSize // W_tok
|
||||
|
||||
// Match Python exactly: reshape treating B=1 as part of contiguous data
|
||||
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
|
||||
|
||||
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
|
||||
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
|
||||
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
|
||||
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
|
||||
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
|
||||
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
|
||||
|
||||
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
|
||||
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
|
||||
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
|
||||
|
||||
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
|
||||
return mlx.Reshape(x, 1, C, H, W)
|
||||
}
|
||||
652
x/imagegen/models/zimage/vae.go
Normal file
652
x/imagegen/models/zimage/vae.go
Normal file
@@ -0,0 +1,652 @@
|
||||
//go:build mlx
|
||||
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"`
|
||||
OutChannels int32 `json:"out_channels"`
|
||||
LatentChannels int32 `json:"latent_channels"`
|
||||
BlockOutChannels []int32 `json:"block_out_channels"`
|
||||
LayersPerBlock int32 `json:"layers_per_block"`
|
||||
NormNumGroups int32 `json:"norm_num_groups"`
|
||||
ScalingFactor float32 `json:"scaling_factor"`
|
||||
ShiftFactor float32 `json:"shift_factor"`
|
||||
}
|
||||
|
||||
// loadVAEConfig loads VAE config from a JSON file
|
||||
func loadVAEConfig(path string) (*VAEConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
var cfg VAEConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// NewGroupNorm creates a group norm layer
|
||||
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
|
||||
return &GroupNormLayer{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
NumGroups: numGroups,
|
||||
Eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, C, H, W]
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Reshape to [B, groups, C/groups, H, W]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, gn.NumGroups, groupSize, H, W)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 2, true)
|
||||
mean = mlx.Mean(mean, 3, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 2, true)
|
||||
variance = mlx.Mean(variance, 3, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, C, H, W]
|
||||
xNorm = mlx.Reshape(xNorm, B, C, H, W)
|
||||
|
||||
// Scale and shift (weight and bias are [C])
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, C, 1, 1)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, C, 1, 1)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer
|
||||
// MLX uses NHWC format, but we store weights in OHWI format for MLX conv
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
Bias *mlx.Array // [out_channels]
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// NewConv2D creates a Conv2D layer
|
||||
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
|
||||
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
|
||||
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
|
||||
// Transpose weight from OIHW to OHWI
|
||||
// [O, I, H, W] -> [O, H, W, I]
|
||||
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
|
||||
return &Conv2D{
|
||||
Weight: weightOHWI,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies convolution
|
||||
// Input x is in NCHW format, we convert to NHWC for MLX, then back to NCHW
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [N, C, H, W] -> [N, H, W, C]
|
||||
xNHWC := mlx.Transpose(x, 0, 2, 3, 1)
|
||||
|
||||
// Conv in NHWC format
|
||||
outNHWC := mlx.Conv2d(xNHWC, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
// Convert back to NCHW: [N, H, W, C] -> [N, C, H, W]
|
||||
out := mlx.Transpose(outNHWC, 0, 3, 1, 2)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, conv.Bias.Dim(0), 1, 1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ResnetBlock2D implements a ResNet block for VAE
|
||||
type ResnetBlock2D struct {
|
||||
Norm1 *GroupNormLayer
|
||||
Conv1 *Conv2D
|
||||
Norm2 *GroupNormLayer
|
||||
Conv2 *Conv2D
|
||||
ConvShortcut *Conv2D // nil if in_channels == out_channels
|
||||
}
|
||||
|
||||
// NewResnetBlock2D creates a ResNet block
|
||||
func NewResnetBlock2D(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block := &ResnetBlock2D{
|
||||
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
|
||||
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
|
||||
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
|
||||
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
|
||||
}
|
||||
|
||||
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
|
||||
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
|
||||
}
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// Forward applies the ResNet block with staged evaluation
|
||||
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: norm1
|
||||
{
|
||||
h = rb.Norm1.Forward(x)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 2: silu + conv1
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 3: norm2
|
||||
{
|
||||
prev := h
|
||||
h = rb.Norm2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Stage 4: silu + conv2
|
||||
{
|
||||
prev := h
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
{
|
||||
prev := h
|
||||
if rb.ConvShortcut != nil {
|
||||
shortcut := rb.ConvShortcut.Forward(x)
|
||||
h = mlx.Add(h, shortcut)
|
||||
} else {
|
||||
h = mlx.Add(h, x)
|
||||
}
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// VAEAttentionBlock implements self-attention for VAE
|
||||
type VAEAttentionBlock struct {
|
||||
GroupNorm *GroupNormLayer
|
||||
ToQWeight *mlx.Array
|
||||
ToQBias *mlx.Array
|
||||
ToKWeight *mlx.Array
|
||||
ToKBias *mlx.Array
|
||||
ToVWeight *mlx.Array
|
||||
ToVBias *mlx.Array
|
||||
ToOutWeight *mlx.Array
|
||||
ToOutBias *mlx.Array
|
||||
NumHeads int32
|
||||
}
|
||||
|
||||
// NewVAEAttentionBlock creates an attention block
|
||||
func NewVAEAttentionBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEAttentionBlock{
|
||||
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
|
||||
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
|
||||
ToQBias: toQBias,
|
||||
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
|
||||
ToKBias: toKBias,
|
||||
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
|
||||
ToVBias: toVBias,
|
||||
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
|
||||
ToOutBias: toOutBias,
|
||||
NumHeads: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies attention with staged evaluation
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
var h *mlx.Array
|
||||
|
||||
// Stage 1: GroupNorm + reshape
|
||||
{
|
||||
h = ab.GroupNorm.Forward(x)
|
||||
h = mlx.Transpose(h, 0, 2, 3, 1)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
var out *mlx.Array
|
||||
|
||||
// Stage 2: Q, K, V projections + attention
|
||||
{
|
||||
q := mlx.Linear(h, ab.ToQWeight)
|
||||
q = mlx.Add(q, ab.ToQBias)
|
||||
k := mlx.Linear(h, ab.ToKWeight)
|
||||
k = mlx.Add(k, ab.ToKBias)
|
||||
v := mlx.Linear(h, ab.ToVWeight)
|
||||
v = mlx.Add(v, ab.ToVBias)
|
||||
h.Free()
|
||||
|
||||
q = mlx.ExpandDims(q, 1)
|
||||
k = mlx.ExpandDims(k, 1)
|
||||
v = mlx.ExpandDims(v, 1)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
out = mlx.Squeeze(out, 1)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
// Stage 3: Output projection + reshape + residual
|
||||
{
|
||||
prev := out
|
||||
out = mlx.Linear(out, ab.ToOutWeight)
|
||||
out = mlx.Add(out, ab.ToOutBias)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Transpose(out, 0, 3, 1, 2)
|
||||
out = mlx.Add(out, residual)
|
||||
prev.Free()
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UpDecoderBlock2D implements an upsampling decoder block
|
||||
type UpDecoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Upsample *Conv2D
|
||||
}
|
||||
|
||||
// NewUpDecoderBlock2D creates an up decoder block
|
||||
func NewUpDecoderBlock2D(weights *safetensors.ModelWeights, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var upsample *Conv2D
|
||||
if hasUpsample {
|
||||
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upsample = NewConv2D(upWeight, upBias, 1, 1)
|
||||
}
|
||||
|
||||
return &UpDecoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Upsample: upsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the up decoder block with staged evaluation to reduce peak memory
|
||||
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range ub.ResnetBlocks {
|
||||
prev := x
|
||||
x = resnet.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
if ub.Upsample != nil {
|
||||
// Stage 1: Upsample2x (nearest neighbor)
|
||||
{
|
||||
prev := x
|
||||
x = Upsample2x(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
|
||||
// Stage 2: Upsample conv
|
||||
{
|
||||
prev := x
|
||||
x = ub.Upsample.Forward(x)
|
||||
prev.Free()
|
||||
mlx.Eval(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block with attention
|
||||
type VAEMidBlock struct {
|
||||
Resnet1 *ResnetBlock2D
|
||||
Attention *VAEAttentionBlock
|
||||
Resnet2 *ResnetBlock2D
|
||||
}
|
||||
|
||||
// NewVAEMidBlock creates the mid block
|
||||
func NewVAEMidBlock(weights *safetensors.ModelWeights, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEMidBlock{
|
||||
Resnet1: resnet1,
|
||||
Attention: attention,
|
||||
Resnet2: resnet2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Forward applies the mid block with staged evaluation
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
prev := x
|
||||
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
// Attention handles its own pools
|
||||
prev = x
|
||||
x = mb.Attention.Forward(x)
|
||||
prev.Free()
|
||||
|
||||
prev = x
|
||||
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
|
||||
prev.Free()
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEDecoder is the full VAE decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
ConvIn *Conv2D
|
||||
MidBlock *VAEMidBlock
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from a directory
|
||||
func (m *VAEDecoder) Load(path string) error {
|
||||
fmt.Println("Loading VAE decoder...")
|
||||
|
||||
// Load config
|
||||
cfg, err := loadVAEConfig(filepath.Join(path, "config.json"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = cfg
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
|
||||
// Load conv_in
|
||||
fmt.Print(" Loading conv_in... ")
|
||||
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load mid block
|
||||
fmt.Print(" Loading mid block... ")
|
||||
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load up blocks
|
||||
fmt.Print(" Loading up blocks... ")
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
hasUpsample := i < numBlocks-1
|
||||
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Printf("✓ [%d blocks]\n", numBlocks)
|
||||
|
||||
// Load conv_norm_out
|
||||
fmt.Print(" Loading conv_norm_out... ")
|
||||
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load conv_out
|
||||
fmt.Print(" Loading conv_out... ")
|
||||
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
|
||||
fmt.Println("✓")
|
||||
|
||||
weights.ReleaseAll()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Uses staged pools to free intermediate arrays and reduce peak memory.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
var h *mlx.Array
|
||||
{
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
h = vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
h = vae.MidBlock.Forward(h)
|
||||
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
{
|
||||
prev := h
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.AddScalar(mlx.MulScalar(h, 0.5), 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
prev.Free()
|
||||
mlx.Eval(h)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Upsample2x performs 2x nearest neighbor upsampling using broadcast.
|
||||
// x: [B, C, H, W] -> [B, C, H*2, W*2]
|
||||
func Upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
// Broadcast to [B, C, H, 2, W, 2]
|
||||
x = mlx.BroadcastTo(x, []int32{B, C, H, 2, W, 2})
|
||||
// Reshape to [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H*2, W*2)
|
||||
|
||||
return x
|
||||
}
|
||||
363
x/imagegen/models/zimage/zimage.go
Normal file
363
x/imagegen/models/zimage/zimage.go
Normal file
@@ -0,0 +1,363 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package zimage implements the Z-Image diffusion transformer model.
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// Layer caching options (speedup via shallow layer reuse)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
CacheInterval int // Refresh cache every N steps (default: 3)
|
||||
CacheLayers int // Number of shallow layers to cache (default: 15)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *Qwen3TextEncoder
|
||||
Transformer *Transformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the Z-Image model from a directory.
|
||||
func (m *Model) Load(modelPath string) error {
|
||||
fmt.Println("Loading Z-Image model...")
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelPath = modelPath
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer", "tokenizer.json")
|
||||
tok, err := tokenizer.Load(tokenizerPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Transformer{}
|
||||
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
CFGScale: cfgScale,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.NegativePrompt != "" {
|
||||
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
|
||||
} else {
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
}
|
||||
if cfg.LayerCache {
|
||||
if cfg.CacheInterval <= 0 {
|
||||
cfg.CacheInterval = 3
|
||||
}
|
||||
if cfg.CacheLayers <= 0 {
|
||||
cfg.CacheLayers = 15 // Half of 30 layers
|
||||
}
|
||||
}
|
||||
|
||||
useCFG := cfg.NegativePrompt != ""
|
||||
tcfg := m.Transformer.TransformerConfig
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
hTok := latentH / tcfg.PatchSize
|
||||
wTok := latentW / tcfg.PatchSize
|
||||
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
maxLen := posEmb.Shape()[1]
|
||||
if useCFG && negEmb.Shape()[1] > maxLen {
|
||||
maxLen = negEmb.Shape()[1]
|
||||
}
|
||||
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
|
||||
maxLen += pad
|
||||
}
|
||||
|
||||
posEmb = padToLength(posEmb, maxLen)
|
||||
if useCFG {
|
||||
negEmb = padToLength(negEmb, maxLen)
|
||||
mlx.Keep(posEmb, negEmb)
|
||||
mlx.Eval(posEmb, negEmb)
|
||||
} else {
|
||||
mlx.Keep(posEmb)
|
||||
mlx.Eval(posEmb)
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
|
||||
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
|
||||
|
||||
// Init latents [B, C, H, W]
|
||||
var latents *mlx.Array
|
||||
{
|
||||
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
}
|
||||
|
||||
// RoPE cache
|
||||
var ropeCache *RoPECache
|
||||
{
|
||||
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
|
||||
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
|
||||
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
|
||||
mlx.Eval(ropeCache.UnifiedCos)
|
||||
}
|
||||
|
||||
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
|
||||
var stepCache *cache.StepCache
|
||||
if cfg.LayerCache {
|
||||
stepCache = cache.NewStepCache(cfg.CacheLayers)
|
||||
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
|
||||
cfg.CacheLayers, cfg.CacheInterval)
|
||||
}
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
stepStart := time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStartCapture(cfg.CapturePath)
|
||||
}
|
||||
|
||||
tCurr := scheduler.Timesteps[i]
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
|
||||
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
var output *mlx.Array
|
||||
if stepCache != nil {
|
||||
// Use layer caching for faster inference
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
// Note: CFG with layer cache shares the cache between pos/neg
|
||||
// This is approximate but fast - neg prompt uses same cached shallow layers
|
||||
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
|
||||
stepCache, i, cfg.CacheInterval)
|
||||
}
|
||||
} else {
|
||||
// Standard forward without caching
|
||||
if useCFG {
|
||||
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
|
||||
diff := mlx.Sub(posOutput, negOutput)
|
||||
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
|
||||
output = mlx.Add(negOutput, scaledDiff)
|
||||
} else {
|
||||
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
|
||||
}
|
||||
}
|
||||
|
||||
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
|
||||
noisePred = mlx.Neg(noisePred)
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
|
||||
// Keep latents and any cached arrays
|
||||
if stepCache != nil {
|
||||
mlx.Keep(stepCache.Arrays()...)
|
||||
}
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStopCapture()
|
||||
}
|
||||
|
||||
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
|
||||
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
|
||||
}
|
||||
|
||||
// Free denoising temporaries before VAE decode
|
||||
posEmb.Free()
|
||||
if negEmb != nil {
|
||||
negEmb.Free()
|
||||
}
|
||||
ropeCache.ImgCos.Free()
|
||||
ropeCache.ImgSin.Free()
|
||||
ropeCache.CapCos.Free()
|
||||
ropeCache.CapSin.Free()
|
||||
ropeCache.UnifiedCos.Free()
|
||||
ropeCache.UnifiedSin.Free()
|
||||
if stepCache != nil {
|
||||
stepCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// padToLength pads a sequence tensor to the target length by repeating the last token.
|
||||
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
currentLen := shape[1]
|
||||
if currentLen >= targetLen {
|
||||
return x
|
||||
}
|
||||
padLen := targetLen - currentLen
|
||||
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
|
||||
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
|
||||
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
|
||||
}
|
||||
|
||||
// CalculateShift computes the mu shift value for dynamic scheduling
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
baseSeqLen := float32(256)
|
||||
maxSeqLen := float32(4096)
|
||||
baseShift := float32(0.5)
|
||||
maxShift := float32(1.15)
|
||||
|
||||
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
|
||||
b := baseShift - m*baseSeqLen
|
||||
return float32(imgSeqLen)*m + b
|
||||
}
|
||||
203
x/imagegen/nn/nn.go
Normal file
203
x/imagegen/nn/nn.go
Normal file
@@ -0,0 +1,203 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package nn provides neural network layer types.
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// Layer is the interface for neural network layers with a Forward method.
|
||||
type Layer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
|
||||
type Linear struct {
|
||||
Weight *mlx.Array `weight:"weight"` // [out_features, in_features]
|
||||
Bias *mlx.Array `weight:"bias,optional"` // [out_features] or nil
|
||||
}
|
||||
|
||||
// NewLinear creates a linear layer.
|
||||
// Weight should be [out_features, in_features].
|
||||
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
|
||||
// Quantizes the weight immediately and evaluates to break lazy dependencies.
|
||||
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
||||
// Eval immediately so bf16 weight can be freed
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward applies the linear transformation: x @ W.T + bias
|
||||
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
||||
w := mlx.Transpose(l.Weight, 1, 0)
|
||||
if l.Bias != nil {
|
||||
return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
|
||||
}
|
||||
return mlx.Linear(x, w)
|
||||
}
|
||||
|
||||
// ToQuantized converts this Linear to a QuantizedLinear.
|
||||
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
|
||||
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
|
||||
return &QuantizedLinear{
|
||||
Weight: qw,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: l.Bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
// QuantizedLinear applies an affine transformation using quantized weights.
|
||||
// Equivalent to mlx.nn.QuantizedLinear.
|
||||
type QuantizedLinear struct {
|
||||
Weight *mlx.Array // Quantized weight data
|
||||
Scales *mlx.Array // Scale factors for dequantization
|
||||
QBiases *mlx.Array // Quantization biases (NOT layer bias)
|
||||
Bias *mlx.Array // Layer bias [output_dims] or nil
|
||||
GroupSize int
|
||||
Bits int
|
||||
Mode string
|
||||
}
|
||||
|
||||
// Forward applies the quantized linear transformation.
|
||||
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
|
||||
if ql.Bias != nil {
|
||||
out = mlx.Add(out, ql.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RMSNorm represents an RMS normalization layer.
|
||||
type RMSNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32 // optional: used if Forward called with eps=0
|
||||
}
|
||||
|
||||
// NewRMSNorm creates an RMSNorm layer (for models not using weight loader).
|
||||
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
|
||||
return &RMSNorm{Weight: weight, Eps: eps}
|
||||
}
|
||||
|
||||
// Forward applies RMS normalization. If eps=0, uses stored Eps.
|
||||
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
if eps == 0 {
|
||||
eps = rn.Eps
|
||||
}
|
||||
return mlx.RMSNorm(x, rn.Weight, eps)
|
||||
}
|
||||
|
||||
// Embedding represents an embedding layer.
|
||||
type Embedding struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// NewEmbedding creates an embedding layer.
|
||||
func NewEmbedding(weight *mlx.Array) *Embedding {
|
||||
return &Embedding{Weight: weight}
|
||||
}
|
||||
|
||||
// Forward looks up embeddings by indices.
|
||||
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
|
||||
return mlx.Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
||||
// RepeatKV repeats K/V tensors for grouped query attention
|
||||
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim]
|
||||
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
|
||||
if repeatFactor == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// Repeat along the new axis
|
||||
reps := []int32{1, 1, repeatFactor, 1, 1}
|
||||
x = mlx.Tile(x, reps)
|
||||
// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// ApplyCausalMask applies causal (lower triangular) mask to attention scores
|
||||
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
||||
// scores: [B, num_heads, S, S]
|
||||
shape := scores.Shape()
|
||||
seqLen := shape[2]
|
||||
|
||||
// Create causal mask: 1 for positions to keep, 0 for positions to mask
|
||||
mask := mlx.Tri(seqLen, seqLen, 0)
|
||||
|
||||
// Where mask is 0, set score to -inf
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
|
||||
// Broadcast mask to match scores shape
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]
|
||||
|
||||
// Use where: if mask > 0, keep scores, else -inf
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// ApplyCausalMaskWithOffset applies causal mask for cached attention
|
||||
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
|
||||
// offset: the starting position of the new queries (i.e., cache length)
|
||||
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
|
||||
if offset == 0 {
|
||||
return ApplyCausalMask(scores)
|
||||
}
|
||||
|
||||
shape := scores.Shape()
|
||||
queryLen := shape[2]
|
||||
keyLen := shape[3]
|
||||
|
||||
// For cached attention, new queries can attend to all cached keys plus
|
||||
// new keys up to and including their position.
|
||||
mask := mlx.Tri(queryLen, keyLen, int(offset))
|
||||
|
||||
negInf := mlx.NewScalarArray(float32(-1e9))
|
||||
mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]
|
||||
|
||||
return mlx.Where(mask, scores, negInf)
|
||||
}
|
||||
|
||||
// LayerNorm represents a standard layer normalization layer (with bias).
|
||||
type LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
|
||||
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
eps := ln.Eps
|
||||
if eps == 0 {
|
||||
eps = 1e-5
|
||||
}
|
||||
// Compute mean and variance along last dimension
|
||||
mean := mlx.Mean(x, -1, true)
|
||||
centered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
|
||||
normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))
|
||||
|
||||
// Scale and shift
|
||||
out := mlx.Mul(normalized, ln.Weight)
|
||||
if ln.Bias != nil {
|
||||
out = mlx.Add(out, ln.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
356
x/imagegen/nn/nn_test.go
Normal file
356
x/imagegen/nn/nn_test.go
Normal file
@@ -0,0 +1,356 @@
|
||||
//go:build mlx
|
||||
|
||||
package nn
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3, // row 0
|
||||
4, 5, 6, // row 1
|
||||
}, []int32{2, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Input: [1, 3]
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [1,1,1] @ [[1,4],[2,5],[3,6]] = [6, 15]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 6 || data[1] != 15 {
|
||||
t.Errorf("expected [6, 15], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearWithBias verifies Linear with bias computes x @ w.T + b correctly.
|
||||
func TestLinearWithBias(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
}, []int32{2, 3})
|
||||
bias := mlx.NewArrayFloat32([]float32{10, 20}, []int32{2})
|
||||
mlx.Eval(weight, bias)
|
||||
|
||||
linear := NewLinear(weight, bias)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{1, 1, 1}, []int32{1, 3})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Expected: [6, 15] + [10, 20] = [16, 35]
|
||||
data := out.Data()
|
||||
if len(data) != 2 || data[0] != 16 || data[1] != 35 {
|
||||
t.Errorf("expected [16, 35], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinearBatched verifies Linear works with batched input.
|
||||
func TestLinearBatched(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
1, 0,
|
||||
0, 1,
|
||||
}, []int32{2, 2}) // Identity
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
// Batch of 3 inputs
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
}, []int32{3, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
// Identity should return same values
|
||||
data := out.Data()
|
||||
expected := []float32{1, 2, 3, 4, 5, 6}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNorm verifies RMSNorm computation.
|
||||
func TestRMSNorm(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{1, 1, 1, 1}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
// Input with known RMS
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// RMS of [2,2,2,2] = 2, so normalized = [1,1,1,1]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-1.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~1.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRMSNormWithScale verifies RMSNorm applies weight scaling.
|
||||
func TestRMSNormWithScale(t *testing.T) {
|
||||
weight := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{4})
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.NewArrayFloat32([]float32{2, 2, 2, 2}, []int32{1, 4})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := norm.Forward(x, 0) // eps=0 uses stored Eps
|
||||
mlx.Eval(out)
|
||||
|
||||
// Normalized [1,1,1,1] * weight [2,2,2,2] = [2,2,2,2]
|
||||
data := out.Data()
|
||||
for i, v := range data {
|
||||
if math.Abs(float64(v-2.0)) > 1e-4 {
|
||||
t.Errorf("at %d: expected ~2.0, got %f", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmbedding verifies embedding lookup.
|
||||
func TestEmbedding(t *testing.T) {
|
||||
// Embedding table: 4 tokens, dim 3
|
||||
weight := mlx.NewArrayFloat32([]float32{
|
||||
0, 0, 0, // token 0
|
||||
1, 1, 1, // token 1
|
||||
2, 2, 2, // token 2
|
||||
3, 3, 3, // token 3
|
||||
}, []int32{4, 3})
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Look up tokens [1, 3, 0]
|
||||
indices := mlx.NewArrayInt32([]int32{1, 3, 0}, []int32{3})
|
||||
mlx.Eval(indices)
|
||||
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
expected := []float32{1, 1, 1, 3, 3, 3, 0, 0, 0}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKV verifies K/V repetition for GQA.
|
||||
func TestRepeatKV(t *testing.T) {
|
||||
// [B=1, num_kv_heads=2, S=2, head_dim=2]
|
||||
x := mlx.NewArrayFloat32([]float32{
|
||||
// head 0
|
||||
1, 2, // pos 0
|
||||
3, 4, // pos 1
|
||||
// head 1
|
||||
5, 6, // pos 0
|
||||
7, 8, // pos 1
|
||||
}, []int32{1, 2, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
// Repeat factor 2: 2 kv heads -> 4 heads
|
||||
out := RepeatKV(x, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
shape := out.Shape()
|
||||
if shape[0] != 1 || shape[1] != 4 || shape[2] != 2 || shape[3] != 2 {
|
||||
t.Errorf("expected shape [1,4,2,2], got %v", shape)
|
||||
}
|
||||
|
||||
data := out.Data()
|
||||
// After repeat: head0, head0, head1, head1
|
||||
expected := []float32{
|
||||
1, 2, 3, 4, // head 0 (original)
|
||||
1, 2, 3, 4, // head 0 (repeat)
|
||||
5, 6, 7, 8, // head 1 (original)
|
||||
5, 6, 7, 8, // head 1 (repeat)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if data[i] != v {
|
||||
t.Errorf("at %d: expected %f, got %f", i, v, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRepeatKVNoOp verifies RepeatKV with factor 1 returns input unchanged.
|
||||
func TestRepeatKVNoOp(t *testing.T) {
|
||||
x := mlx.NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{1, 1, 2, 2})
|
||||
mlx.Eval(x)
|
||||
|
||||
out := RepeatKV(x, 1)
|
||||
// Should return same pointer
|
||||
if out != x {
|
||||
t.Error("RepeatKV with factor 1 should return input unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMask verifies causal masking.
|
||||
func TestApplyCausalMask(t *testing.T) {
|
||||
// [B=1, heads=1, S=3, S=3] - all ones
|
||||
scores := mlx.Ones(1, 1, 3, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMask(scores)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Lower triangular should be 1, upper should be -1e9
|
||||
// Row 0: [1, -inf, -inf]
|
||||
// Row 1: [1, 1, -inf]
|
||||
// Row 2: [1, 1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 || data[2] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:3])
|
||||
}
|
||||
if data[3] != 1 || data[4] != 1 || data[5] >= 0 {
|
||||
t.Errorf("row 1 wrong: %v", data[3:6])
|
||||
}
|
||||
if data[6] != 1 || data[7] != 1 || data[8] != 1 {
|
||||
t.Errorf("row 2 wrong: %v", data[6:9])
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffset verifies causal masking with cache offset.
|
||||
func TestApplyCausalMaskWithOffset(t *testing.T) {
|
||||
// Simulating: cache has 2 tokens, adding 1 new query
|
||||
// scores: [B=1, heads=1, queryLen=1, keyLen=3]
|
||||
scores := mlx.Ones(1, 1, 1, 3)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 2)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// With offset=2, query at position 2 can attend to all 3 positions
|
||||
if data[0] != 1 || data[1] != 1 || data[2] != 1 {
|
||||
t.Errorf("expected [1, 1, 1], got %v", data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyCausalMaskWithOffsetZero verifies offset=0 falls back to regular causal.
|
||||
func TestApplyCausalMaskWithOffsetZero(t *testing.T) {
|
||||
scores := mlx.Ones(1, 1, 2, 2)
|
||||
mlx.Eval(scores)
|
||||
|
||||
out := ApplyCausalMaskWithOffset(scores, 0)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Data()
|
||||
// Standard causal: [1, -inf], [1, 1]
|
||||
if data[0] != 1 || data[1] >= 0 {
|
||||
t.Errorf("row 0 wrong: %v", data[0:2])
|
||||
}
|
||||
if data[2] != 1 || data[3] != 1 {
|
||||
t.Errorf("row 1 wrong: %v", data[2:4])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearSmall benchmarks small Linear forward pass.
|
||||
func BenchmarkLinearSmall(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{256, 256}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 256}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLinearLarge benchmarks larger Linear forward pass.
|
||||
func BenchmarkLinearLarge(b *testing.B) {
|
||||
weight := mlx.RandomNormal([]int32{4096, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
linear := NewLinear(weight, nil)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 43)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRMSNorm benchmarks RMSNorm forward pass.
|
||||
func BenchmarkRMSNorm(b *testing.B) {
|
||||
weight := mlx.Ones(4096)
|
||||
mlx.Eval(weight)
|
||||
|
||||
norm := NewRMSNorm(weight, 1e-5)
|
||||
|
||||
x := mlx.RandomNormal([]int32{1, 4096}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := norm.Forward(x, 0)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEmbedding benchmarks embedding lookup.
|
||||
func BenchmarkEmbedding(b *testing.B) {
|
||||
// Typical vocab size
|
||||
weight := mlx.RandomNormal([]int32{32000, 4096}, 42)
|
||||
mlx.Eval(weight)
|
||||
|
||||
emb := NewEmbedding(weight)
|
||||
|
||||
// Single token lookup
|
||||
indices := mlx.NewArrayInt32([]int32{1000}, []int32{1})
|
||||
mlx.Eval(indices)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRepeatKV benchmarks K/V repetition.
|
||||
func BenchmarkRepeatKV(b *testing.B) {
|
||||
// Typical GQA setup: 8 kv heads -> 32 heads
|
||||
x := mlx.RandomNormal([]int32{1, 8, 512, 128}, 42)
|
||||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := RepeatKV(x, 4)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
}
|
||||
170
x/imagegen/safetensors/loader.go
Normal file
170
x/imagegen/safetensors/loader.go
Normal file
@@ -0,0 +1,170 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// LoadModule loads weights into a struct using reflection and struct tags.
|
||||
//
|
||||
// Struct tags use the format: `weight:"path[,optional]"`
|
||||
// - path: the weight name suffix (appended to prefix)
|
||||
// - optional: if present, missing weights don't cause errors
|
||||
// - "-": skip this field entirely
|
||||
// - no tag on struct pointer: recurse with current prefix
|
||||
// - no tag on *mlx.Array: skip (computed fields don't need loading)
|
||||
//
|
||||
// For slices of struct pointers, the loader iterates with .0, .1, .2... suffixes.
|
||||
// The slice must be pre-allocated to the correct length.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type Attention struct {
|
||||
// QProj *nn.Linear `weight:"self_attn.q_proj"`
|
||||
// KProj *nn.Linear `weight:"self_attn.k_proj"`
|
||||
// Cache *mlx.Array // no tag = skipped (computed field)
|
||||
// }
|
||||
//
|
||||
// err := LoadModule(&attn, weights, "model.layers.0")
|
||||
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", v.Kind())
|
||||
}
|
||||
|
||||
var errs []string
|
||||
loadStruct(v, weights, prefix, &errs, false)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("LoadModule: missing weights:\n %s", strings.Join(errs, "\n "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStruct recursively loads weights into a struct value.
|
||||
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldVal := v.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !fieldVal.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag
|
||||
tag, hasTag := field.Tag.Lookup("weight")
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse tag options
|
||||
optional := parentOptional
|
||||
weightPath := tag
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
weightPath = tag[:idx]
|
||||
if strings.Contains(tag[idx+1:], "optional") {
|
||||
optional = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build full path
|
||||
fullPath := joinPath(prefix, weightPath)
|
||||
|
||||
// For struct pointers without a tag, recurse with current prefix
|
||||
if !hasTag && fieldVal.Kind() == reflect.Ptr {
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Struct && elemType != reflect.TypeOf(mlx.Array{}) {
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle by kind
|
||||
switch fieldVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
|
||||
// *mlx.Array - load directly (but skip if no tag - computed fields)
|
||||
if fieldVal.Type() == reflect.TypeOf((*mlx.Array)(nil)) {
|
||||
if !hasTag {
|
||||
continue // no tag on *mlx.Array = computed field, skip
|
||||
}
|
||||
arr, err := weights.GetTensor(fullPath)
|
||||
if err != nil {
|
||||
if !optional {
|
||||
*errs = append(*errs, fullPath)
|
||||
}
|
||||
continue
|
||||
}
|
||||
fieldVal.Set(reflect.ValueOf(arr))
|
||||
continue
|
||||
}
|
||||
|
||||
// Pointer to struct - allocate and recurse
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
if optional && !hasWeightsWithPrefix(weights, fullPath) {
|
||||
continue
|
||||
}
|
||||
if fieldVal.IsNil() {
|
||||
fieldVal.Set(reflect.New(elemType))
|
||||
}
|
||||
loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
elemType := fieldVal.Type().Elem()
|
||||
if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
|
||||
loadSlice(fieldVal, weights, fullPath, errs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasWeightsWithPrefix checks if any weights exist with the given prefix.
|
||||
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
|
||||
for _, name := range weights.ListTensors() {
|
||||
if strings.HasPrefix(name, prefix+".") || name == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadSlice loads weights into each element of a slice of struct pointers.
|
||||
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
|
||||
elemStructType := v.Type().Elem().Elem()
|
||||
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i)
|
||||
if elem.IsNil() {
|
||||
elem.Set(reflect.New(elemStructType))
|
||||
}
|
||||
loadStruct(elem.Elem(), weights, fmt.Sprintf("%s.%d", prefix, i), errs, false)
|
||||
}
|
||||
}
|
||||
|
||||
// joinPath joins path segments with dots, handling empty segments.
|
||||
func joinPath(prefix, suffix string) string {
|
||||
if prefix == "" {
|
||||
return suffix
|
||||
}
|
||||
if suffix == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + "." + suffix
|
||||
}
|
||||
280
x/imagegen/safetensors/safetensors.go
Normal file
280
x/imagegen/safetensors/safetensors.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SafetensorHeader represents the JSON header of a safetensors file
|
||||
type SafetensorHeader map[string]TensorInfo
|
||||
|
||||
// TensorInfo contains metadata about a tensor
|
||||
type TensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int32 `json:"shape"`
|
||||
DataOffsets [2]int `json:"data_offsets"`
|
||||
}
|
||||
|
||||
// parseSafetensorHeader reads only the JSON header from a safetensors file.
|
||||
func parseSafetensorHeader(path string) (SafetensorHeader, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := f.Read(headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
var header SafetensorHeader
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
delete(header, "__metadata__")
|
||||
return header, nil
|
||||
}
|
||||
|
||||
// dtypeFromString converts safetensors dtype string to mlx.Dtype
|
||||
func dtypeFromString(s string) mlx.Dtype {
|
||||
switch strings.ToUpper(s) {
|
||||
case "F32", "FLOAT32":
|
||||
return mlx.DtypeFloat32
|
||||
case "F16", "FLOAT16":
|
||||
return mlx.DtypeFloat16
|
||||
case "BF16", "BFLOAT16":
|
||||
return mlx.DtypeBFloat16
|
||||
case "I32", "INT32":
|
||||
return mlx.DtypeInt32
|
||||
case "I64", "INT64":
|
||||
return mlx.DtypeInt64
|
||||
case "U8", "UINT8":
|
||||
return mlx.DtypeUint8
|
||||
default:
|
||||
return mlx.DtypeFloat32
|
||||
}
|
||||
}
|
||||
|
||||
// ModelWeights manages weights from multiple safetensor files.
|
||||
type ModelWeights struct {
|
||||
dir string // Model directory
|
||||
tensorFiles map[string]string // tensor name -> file path
|
||||
tensorInfo map[string]TensorInfo // tensor name -> metadata
|
||||
nativeCache map[string]*mlx.SafetensorsFile // file path -> loaded native handle
|
||||
cache map[string]*mlx.Array // tensor name -> array (after Load)
|
||||
}
|
||||
|
||||
// LoadModelWeights scans safetensor files and builds a tensor index.
|
||||
// This only reads JSON headers, not tensor data.
|
||||
func LoadModelWeights(dir string) (*ModelWeights, error) {
|
||||
mw := &ModelWeights{
|
||||
dir: dir,
|
||||
tensorFiles: make(map[string]string),
|
||||
tensorInfo: make(map[string]TensorInfo),
|
||||
nativeCache: make(map[string]*mlx.SafetensorsFile),
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
|
||||
header, err := parseSafetensorHeader(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", entry.Name(), err)
|
||||
}
|
||||
|
||||
for name, info := range header {
|
||||
mw.tensorFiles[name] = path
|
||||
mw.tensorInfo[name] = info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(mw.tensorFiles) == 0 {
|
||||
return nil, fmt.Errorf("no safetensor files found in %s", dir)
|
||||
}
|
||||
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// Load loads all tensors into cache with the specified dtype.
|
||||
// If dtype is 0, tensors are loaded in their original dtype.
|
||||
// Automatically uses streaming (memory-efficient) when dtype conversion is needed,
|
||||
// or native loading when tensors are already in the target dtype.
|
||||
func (mw *ModelWeights) Load(dtype mlx.Dtype) error {
|
||||
if dtype == 0 {
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// Check if any tensor needs conversion
|
||||
needsConversion := false
|
||||
for name := range mw.tensorFiles {
|
||||
info := mw.tensorInfo[name]
|
||||
if dtypeFromString(info.Dtype) != dtype {
|
||||
needsConversion = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsConversion {
|
||||
return mw.loadStreaming(dtype)
|
||||
}
|
||||
return mw.loadNative()
|
||||
}
|
||||
|
||||
// loadNative loads all tensors using the native memory-mapped loader.
|
||||
func (mw *ModelWeights) loadNative() error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
arr := native.Get(name)
|
||||
if arr == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
mw.cache[name] = arr
|
||||
}
|
||||
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStreaming loads tensors with dtype conversion.
|
||||
// Uses the same pattern as Python: replace each entry in the map after conversion,
|
||||
// so the original tensor loses its reference and can be freed.
|
||||
func (mw *ModelWeights) loadStreaming(dtype mlx.Dtype) error {
|
||||
mw.cache = make(map[string]*mlx.Array)
|
||||
|
||||
fileToTensors := make(map[string][]string)
|
||||
for name, path := range mw.tensorFiles {
|
||||
fileToTensors[path] = append(fileToTensors[path], name)
|
||||
}
|
||||
|
||||
for path, names := range fileToTensors {
|
||||
native, err := mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
src := native.Get(name)
|
||||
if src == nil {
|
||||
native.Free()
|
||||
return fmt.Errorf("tensor %q not found in %s", name, path)
|
||||
}
|
||||
|
||||
dst := mlx.AsType(src, dtype)
|
||||
mlx.Eval(dst)
|
||||
native.Set(name, dst)
|
||||
mw.cache[name] = dst
|
||||
}
|
||||
|
||||
native.Free()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a tensor from cache. Call Load() first.
|
||||
func (mw *ModelWeights) Get(name string) (*mlx.Array, error) {
|
||||
if mw.cache == nil {
|
||||
return nil, fmt.Errorf("cache not initialized: call Load() first")
|
||||
}
|
||||
arr, ok := mw.cache[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found in cache", name)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// GetTensor loads a tensor using the native loader without caching.
|
||||
// For bulk loading, use Load() + Get() instead.
|
||||
func (mw *ModelWeights) GetTensor(name string) (*mlx.Array, error) {
|
||||
if mw.cache != nil {
|
||||
if arr, ok := mw.cache[name]; ok {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
|
||||
path, ok := mw.tensorFiles[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tensor %q not found", name)
|
||||
}
|
||||
|
||||
native, ok := mw.nativeCache[path]
|
||||
if !ok {
|
||||
var err error
|
||||
native, err = mlx.LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load %s: %w", path, err)
|
||||
}
|
||||
mw.nativeCache[path] = native
|
||||
}
|
||||
|
||||
return native.Get(name), nil
|
||||
}
|
||||
|
||||
// GetTensorInfo returns metadata about a tensor without loading it.
|
||||
func (mw *ModelWeights) GetTensorInfo(name string) (TensorInfo, bool) {
|
||||
info, ok := mw.tensorInfo[name]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ListTensors returns all tensor names.
|
||||
func (mw *ModelWeights) ListTensors() []string {
|
||||
names := make([]string, 0, len(mw.tensorFiles))
|
||||
for name := range mw.tensorFiles {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// HasTensor checks if a tensor exists.
|
||||
func (mw *ModelWeights) HasTensor(name string) bool {
|
||||
_, ok := mw.tensorFiles[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ReleaseAll releases all cached native file handles.
|
||||
func (mw *ModelWeights) ReleaseAll() {
|
||||
for path, native := range mw.nativeCache {
|
||||
native.Free()
|
||||
delete(mw.nativeCache, path)
|
||||
}
|
||||
}
|
||||
|
||||
167
x/imagegen/safetensors/safetensors_test.go
Normal file
167
x/imagegen/safetensors/safetensors_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
//go:build mlx
|
||||
|
||||
package safetensors
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestLoadModelWeights(t *testing.T) {
|
||||
// Skip if no model available
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Check we found tensors
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Fatal("no tensors found")
|
||||
}
|
||||
t.Logf("found %d tensors", len(tensors))
|
||||
|
||||
// Check HasTensor
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Errorf("HasTensor(%q) = false", tensors[0])
|
||||
}
|
||||
if mw.HasTensor("nonexistent.weight") {
|
||||
t.Error("HasTensor returned true for nonexistent tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
tensors := mw.ListTensors()
|
||||
if len(tensors) == 0 {
|
||||
t.Skip("no tensors")
|
||||
}
|
||||
|
||||
// Load first tensor
|
||||
arr, err := mw.GetTensor(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
|
||||
}
|
||||
|
||||
// Verify it has a shape
|
||||
shape := arr.Shape()
|
||||
if len(shape) == 0 {
|
||||
t.Error("tensor has no shape")
|
||||
}
|
||||
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLoadWithDtype(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// Load all tensors as bfloat16
|
||||
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
|
||||
// Get a tensor from cache
|
||||
tensors := mw.ListTensors()
|
||||
arr, err := mw.Get(tensors[0])
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
|
||||
// Verify dtype (unless it was already bf16)
|
||||
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
|
||||
}
|
||||
|
||||
func TestLookupTensor(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
mw, err := LoadModelWeights(modelDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModelWeights: %v", err)
|
||||
}
|
||||
defer mw.ReleaseAll()
|
||||
|
||||
// HasTensor returns false for nonexistent
|
||||
if mw.HasTensor("nonexistent") {
|
||||
t.Error("HasTensor should return false for nonexistent")
|
||||
}
|
||||
|
||||
// HasTensor returns true for existing tensor
|
||||
tensors := mw.ListTensors()
|
||||
if !mw.HasTensor(tensors[0]) {
|
||||
t.Error("HasTensor should return true for existing tensor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorHeader(t *testing.T) {
|
||||
modelDir := "../weights/gpt-oss-20b"
|
||||
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
||||
t.Skip("model weights not available")
|
||||
}
|
||||
|
||||
// Find a safetensors file
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var stFile string
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".safetensors" {
|
||||
stFile = filepath.Join(modelDir, e.Name())
|
||||
break
|
||||
}
|
||||
}
|
||||
if stFile == "" {
|
||||
t.Skip("no safetensors file found")
|
||||
}
|
||||
|
||||
header, err := parseSafetensorHeader(stFile)
|
||||
if err != nil {
|
||||
t.Fatalf("parseSafetensorHeader: %v", err)
|
||||
}
|
||||
|
||||
if len(header) == 0 {
|
||||
t.Error("header is empty")
|
||||
}
|
||||
|
||||
// Check a tensor has valid info
|
||||
for name, info := range header {
|
||||
if info.Dtype == "" {
|
||||
t.Errorf("%s: empty dtype", name)
|
||||
}
|
||||
if len(info.Shape) == 0 {
|
||||
t.Errorf("%s: empty shape", name)
|
||||
}
|
||||
break // just check one
|
||||
}
|
||||
}
|
||||
85
x/imagegen/tokenizer/README.md
Normal file
85
x/imagegen/tokenizer/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Tokenizer
|
||||
|
||||
Tokenizer for LLM inference supporting BPE, SentencePiece, and WordPiece algorithms. The goal of this package is to see if a pure Go tokenizer can be fast and correct. It primarily supports the `imagegen` models however it (or parts of it) could be considered to replace Ollama's tokenizer in the `model` package.
|
||||
|
||||
## Features
|
||||
|
||||
- **BPE (Byte Pair Encoding)** - GPT-2/Llama style with byte-level encoding
|
||||
- **SentencePiece** - Gemma style with `▁` space handling
|
||||
- **WordPiece** - BERT style with `##` continuation tokens
|
||||
- **Parallel encoding** - Automatic parallelization for inputs >4KB
|
||||
- **HuggingFace compatible** - Loads `tokenizer.json` directly
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import "github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
|
||||
// Load from HuggingFace model directory
|
||||
tok, err := tokenizer.Load("./weights/Llama-3.2-1B")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Encode text to token IDs
|
||||
ids := tok.Encode("Hello, world!", false) // false = don't add BOS
|
||||
|
||||
// Decode back to text
|
||||
text := tok.Decode(ids)
|
||||
|
||||
// Check special tokens
|
||||
if tok.IsEOS(ids[len(ids)-1]) {
|
||||
// End of sequence
|
||||
}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
Benchmarks on Apple M3 Max:
|
||||
|
||||
| Input Size | Encode | Decode | Tokens |
|
||||
|------------|--------|--------|--------|
|
||||
| 1 KB | 14.5 MB/s | 267 MB/s | 231 |
|
||||
| 10 KB | 10.9 MB/s | 321 MB/s | 2,301 |
|
||||
| 100 KB | 8.9 MB/s | 311 MB/s | 23,001 |
|
||||
| 1 MB | 9.6 MB/s | 321 MB/s | 230,001 |
|
||||
|
||||
Comparison with other implementations (10 MB input):
|
||||
|
||||
| Implementation | Encode Speed | Notes |
|
||||
|----------------|--------------|-------|
|
||||
| Engine (this) | ~10 MB/s | stdlib RE2, parallel >4KB |
|
||||
| tiktoken (Rust) | ~17 MB/s | Highly optimized regex |
|
||||
| Ollama (Go) | ~2-3 MB/s | regexp2 backtracking |
|
||||
|
||||
## Performance Opportunities
|
||||
|
||||
Potential optimizations not yet implemented:
|
||||
|
||||
| Optimization | Expected Gain | Complexity |
|
||||
|--------------|---------------|------------|
|
||||
| Aho-Corasick for special tokens | 2-3x for many special tokens | Medium |
|
||||
| Custom regex engine (like tiktoken) | 1.5-2x | High |
|
||||
| SIMD byte scanning | 1.3-1.5x for pretokenizer | Medium |
|
||||
| Assembly BPE merge loop | 1.2-1.5x | High |
|
||||
| Memoization for repeated substrings | Variable | Low |
|
||||
|
||||
Current bottleneck is the pretokenizer regex (~60% of encode time). tiktoken achieves ~17 MB/s with a hand-tuned Rust regex engine.
|
||||
|
||||
## Not Yet Implemented
|
||||
|
||||
| Feature | Used By | Notes |
|
||||
|---------|---------|-------|
|
||||
| Unigram tokenizer | T5, ALBERT, mBART | Different algorithm (not BPE) |
|
||||
| Unicode normalizers | Some multilingual models | NFD, NFKC, lowercase, etc. |
|
||||
| Custom pretokenizers | Model-specific | Beyond standard patterns |
|
||||
|
||||
Most HuggingFace models use BPE or SentencePiece, which are fully supported. WordPiece (BERT-style) is also supported with standard `[UNK]` fallback for out-of-vocabulary characters.
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `tokenizer.go` | Main implementation (~1000 lines) |
|
||||
| `tokenizer_test.go` | Tests and benchmarks |
|
||||
| `testdata/` | Mini tokenizer for unit tests |
|
||||
1
x/imagegen/tokenizer/testdata/mini_llama.json
vendored
Normal file
1
x/imagegen/tokenizer/testdata/mini_llama.json
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"model": {"type": "BPE", "vocab": {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "fé": 59958, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "r": 81, "q": 80, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "¡": 94, "¢": 95, "£": 96, "¤": 97, "¥": 98, "¦": 99, "§": 100, "¨": 101, "World": 10343, "©": 102, "ª": 103, "«": 104, "¬": 105, "®": 106, "world": 14957, "¯": 107, "°": 108, "±": 109, "²": 110, "³": 111, "´": 112, "µ": 113, "¶": 114, "·": 115, "¸": 116, "¹": 117, "º": 118, "»": 119, "¼": 120, "½": 121, "¾": 122, "¿": 123, "À": 124, "Á": 125, "Â": 126, "Ã": 127, "Ä": 128, "Å": 129, "Æ": 130, "Ç": 131, "È": 132, "É": 133, "Ê": 134, "Ë": 135, "Ì": 136, "Í": 137, "Î": 138, "Ï": 139, "Ð": 140, "Ñ": 141, "Ò": 142, "Ó": 143, "Ô": 144, "Õ": 145, "Ö": 146, "×": 147, "Ø": 148, "Ù": 149, "Ú": 150, "Û": 151, "Ü": 152, "Ý": 153, "Þ": 154, "ß": 155, "à": 156, "á": 157, "â": 158, "ã": 159, "ä": 160, "å": 161, "æ": 162, "ç": 163, "è": 164, "é": 165, "ê": 166, "ë": 167, "ì": 168, "Ġhello": 24748, "í": 169, "î": 170, "ï": 171, "ð": 172, "ñ": 173, "Hello": 9906, "ò": 174, "ó": 175, "ô": 176, "õ": 177, "ö": 178, "Ġ{}": 4792, "÷": 179, "ø": 180, "ù": 181, "ú": 182, "û": 183, "ü": 184, "ý": 185, "þ": 186, "ÿ": 187, "Ā": 188, "ā": 189, "Ă": 190, "ă": 191, "Ċ": 198, "Ą": 192, "ą": 193, "Ć": 194, "ć": 195, "Ĉ": 196, "ĉ": 197, "ċ": 199, "Č": 200, "č": 201, "Ď": 202, "ď": 203, "Đ": 204, "đ": 205, "Ē": 206, "ē": 207, "Ĕ": 208, "ĕ": 209, "Ė": 210, "ė": 211, "Ę": 212, "ę": 213, "Ġ": 220, "Ě": 214, "ě": 215, "Ĝ": 216, "ĝ": 217, "Ğ": 218, "ğ": 219, "ġ": 221, "Ģ": 222, "ģ": 223, "Ĥ": 224, "ĥ": 225, "Ħ": 226, "ħ": 227, "Ĩ": 228, "ĩ": 229, "Ī": 230, "ī": 231, "Ĭ": 232, "ĭ": 233, "Į": 234, "į": 235, "İ": 236, "ı": 237, "IJ": 238, "ij": 239, "Ĵ": 240, "ĵ": 241, "Ķ": 242, "ķ": 243, "ĸ": 244, "Ĺ": 245, "ĺ": 246, "Ļ": 247, "ļ": 248, "Ľ": 249, "ĠĠ": 256, "ľ": 250, "Ŀ": 251, "ŀ": 252, "Ł": 253, "rer": 38149, "ĠĠĠ": 262, "ł": 254, "Ń": 255, "'m": 2846, "'re": 2351, "can": 4919, "func": 2900, "()": 368, "Ġworld": 1917, "Ġmain": 1925, "00": 410, "123": 4513, "000": 931, "ca": 936, "'t": 956, "é": 978, "hello": 15339, "Ġw": 289, "orld": 1410, "Ġwor": 4191, "ld": 509, "main": 3902, "Ġm": 296, "ain": 467, "Ġma": 7643, "in": 258, "Ġmai": 17154, "re": 265, "'r": 97670, "unc": 1371, "fun": 12158, "fu": 33721, "nc": 1031, "ma": 1764, "mai": 77585, "wor": 50810, "or": 269, "Ġwo": 24670, "23": 1419, "12": 717, "{}": 6390, "Ġ{": 314, "an": 276, "ello": 4896, "Hel": 33813, "lo": 385, "Hell": 81394, "un": 359, "hel": 50222, "hell": 57195, "ai": 2192, "wo": 1146, "Ġh": 305, "Ġhel": 11591, "Ġhell": 15123, "el": 301, "He": 1548, "er": 261, "he": 383, "ell": 616, "ll": 657}, "merges": ["Ġ Ġ", "Ġ ĠĠ", "ĠĠ Ġ", "( )", "0 0", "0 00", "00 0", "c a", "' t", "à ©", "Ġ world", "Ġw orld", "Ġwor ld", "Ġ main", "Ġm ain", "Ġma in", "Ġmai n", "' re", "'r e", "' m", "f unc", "fun c", "fu nc", "m ain", "ma in", "mai n", "Ġ wor", "Ġw or", "Ġwo r", "1 23", "12 3", "Ġ {}", "Ġ{ }", "c an", "ca n", "{ }", "Ġ ma", "Ġm a", "H ello", "Hel lo", "Hell o", "W orld", "f un", "fu n", "w orld", "wor ld", "h ello", "hel lo", "hell o", "Ġ mai", "Ġm ai", "Ġma i", "Ġ wo", "Ġw o", "Ġ hello", "Ġh ello", "Ġhel lo", "Ġhell o", "f u", "H el", "He l", "r er", "re r", "h el", "he l", "w or", "wo r", "h ell", "he ll", "hel l", "f é", "m ai", "ma i", "H ell", "He ll", "Hel l", "' r"]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"}, "behavior": "Isolated", "invert": false}, {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": false}]}, "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true}, "added_tokens": [{"id": 128000, "content": "<|begin_of_text|>", "special": true}, {"id": 128001, "content": "<|end_of_text|>", "special": true}]}
|
||||
1013
x/imagegen/tokenizer/tokenizer.go
Normal file
1013
x/imagegen/tokenizer/tokenizer.go
Normal file
File diff suppressed because it is too large
Load Diff
785
x/imagegen/tokenizer/tokenizer_test.go
Normal file
785
x/imagegen/tokenizer/tokenizer_test.go
Normal file
@@ -0,0 +1,785 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPatternCompilation validates that HuggingFace pretokenizer patterns
|
||||
// can be rewritten for Go's RE2 regexp engine and compiled successfully.
|
||||
func TestPatternCompilation(t *testing.T) {
|
||||
patterns := []struct {
|
||||
name string
|
||||
pattern string
|
||||
}{
|
||||
{"llama3", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"qwen2", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt4o", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`},
|
||||
{"gpt2", `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`},
|
||||
{"deepseek_cjk", `[一-龥\x{3040}-ゟ゠-ヿ]+`},
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
rewritten := rewritePatternForRE2(p.pattern)
|
||||
if _, err := regexp.Compile(rewritten); err != nil {
|
||||
t.Errorf("failed to compile pattern: %v\noriginal: %s\nrewritten: %s", err, p.pattern, rewritten)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoundtrip verifies the fundamental property: encode(text) -> decode -> text
|
||||
// This is the key invariant from tiktoken's test suite.
|
||||
func TestRoundtrip(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Test cases covering key edge cases from tiktoken
|
||||
inputs := []string{
|
||||
// Empty and simple
|
||||
"",
|
||||
"a",
|
||||
"hello",
|
||||
"hello world",
|
||||
|
||||
// Whitespace edge cases
|
||||
" ",
|
||||
" ",
|
||||
" ",
|
||||
" hello",
|
||||
"hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"hello world",
|
||||
"\t",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"hello\nworld",
|
||||
"hello\n\nworld",
|
||||
|
||||
// Contractions
|
||||
"don't",
|
||||
"I'm",
|
||||
"we'll",
|
||||
"they're",
|
||||
"it's",
|
||||
"DON'T", // uppercase
|
||||
|
||||
// Numbers
|
||||
"123",
|
||||
"1234567890",
|
||||
"3.14159",
|
||||
"$100",
|
||||
"50%",
|
||||
|
||||
// Unicode
|
||||
"こんにちは", // Japanese
|
||||
"你好", // Chinese
|
||||
"مرحبا", // Arabic (RTL)
|
||||
"🎉", // Emoji
|
||||
"Hello 世界", // Mixed
|
||||
"café", // Accented
|
||||
"naïve", // Diaeresis
|
||||
"Ω≈ç√∫", // Math symbols
|
||||
|
||||
// Code
|
||||
"func main() {}",
|
||||
"if (x == 0) { return; }",
|
||||
"import \"fmt\"",
|
||||
"x := 42",
|
||||
"// comment",
|
||||
"/* block */",
|
||||
|
||||
// Repetition (tiktoken specifically tests this)
|
||||
"aaaa",
|
||||
"aaaaaaaaaaaa",
|
||||
strings.Repeat("a", 100),
|
||||
strings.Repeat("hello ", 50),
|
||||
|
||||
// Punctuation
|
||||
"...",
|
||||
"!!!",
|
||||
"???",
|
||||
"hello, world!",
|
||||
"(parentheses)",
|
||||
"[brackets]",
|
||||
"{braces}",
|
||||
|
||||
// Mixed complexity
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
|
||||
"func TestRoundtrip(t *testing.T) { t.Run(\"test\", func(t *testing.T) {}) }",
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
name := input
|
||||
if len(name) > 30 {
|
||||
name = name[:30] + "..."
|
||||
}
|
||||
if name == "" {
|
||||
name = "<empty>"
|
||||
}
|
||||
name = strings.ReplaceAll(name, "\n", "\\n")
|
||||
name = strings.ReplaceAll(name, "\t", "\\t")
|
||||
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n tokens: %v\n decoded: %q", input, tokens, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpecialTokens verifies that special tokens are handled correctly
|
||||
func TestSpecialTokens(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// Special tokens should be preserved through encode/decode
|
||||
t.Run("bos_preserved", func(t *testing.T) {
|
||||
if tok.BOS() < 0 {
|
||||
t.Skip("no BOS token")
|
||||
}
|
||||
tokens := tok.Encode("hello", true)
|
||||
if len(tokens) == 0 || tokens[0] != tok.BOS() {
|
||||
t.Errorf("BOS not prepended: got %v, want first token to be %d", tokens, tok.BOS())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special_token_split", func(t *testing.T) {
|
||||
// If we have special tokens, verify they're split correctly
|
||||
for tokenStr, tokenID := range tok.specialTokens {
|
||||
input := "before" + tokenStr + "after"
|
||||
tokens := tok.Encode(input, false)
|
||||
|
||||
found := false
|
||||
for _, id := range tokens {
|
||||
if id == tokenID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("special token %q (id=%d) not found in encoding of %q: %v",
|
||||
tokenStr, tokenID, input, tokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrency verifies thread-safe encoding
|
||||
func TestConcurrency(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "The quick brown fox jumps over the lazy dog."
|
||||
expected := tok.Encode(input, false)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
got := tok.Encode(input, false)
|
||||
if len(got) != len(expected) {
|
||||
errors <- nil // just signal error
|
||||
return
|
||||
}
|
||||
for j := range got {
|
||||
if got[j] != expected[j] {
|
||||
errors <- nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("concurrent encoding produced inconsistent results")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration runs against real model directories, comparing with Python transformers.
|
||||
// Skips if model weights are not available.
|
||||
func TestIntegration(t *testing.T) {
|
||||
models := []string{
|
||||
"../weights/Llama-3.2-1B",
|
||||
"../weights/gemma-3-1b-it",
|
||||
"../weights/gpt-oss-20b",
|
||||
}
|
||||
|
||||
// Test inputs covering various edge cases
|
||||
inputs := []string{
|
||||
"Hello, world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"こんにちは世界",
|
||||
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
||||
"1234567890",
|
||||
" spaces ",
|
||||
"don't won't can't",
|
||||
}
|
||||
|
||||
for _, modelPath := range models {
|
||||
modelName := filepath.Base(modelPath)
|
||||
|
||||
t.Run(modelName, func(t *testing.T) {
|
||||
tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
|
||||
if _, err := os.Stat(tokenizerPath); err != nil {
|
||||
t.Skipf("skipping: %s not found", tokenizerPath)
|
||||
}
|
||||
|
||||
tok, err := Load(tokenizerPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
t.Run(truncate(input, 20), func(t *testing.T) {
|
||||
// Test roundtrip
|
||||
tokens := tok.Encode(input, false)
|
||||
decoded := tok.Decode(tokens)
|
||||
if decoded != input {
|
||||
t.Errorf("roundtrip failed:\n input: %q\n decoded: %q", input, decoded)
|
||||
}
|
||||
|
||||
// Compare with Python if available
|
||||
if pythonTokens, err := pythonEncode(modelPath, input); err == nil {
|
||||
if !equalInt32Slice(tokens, pythonTokens) {
|
||||
t.Errorf("mismatch with Python:\n go: %v\n python: %v", tokens, pythonTokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// pythonEncode calls Python transformers to encode text, for comparison
|
||||
func pythonEncode(modelPath, text string) ([]int32, error) {
|
||||
script := `
|
||||
import sys, json
|
||||
from transformers import AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(sys.argv[1])
|
||||
tokens = tok.encode(sys.argv[2], add_special_tokens=False)
|
||||
print(json.dumps(tokens))
|
||||
`
|
||||
cmd := exec.Command("python3", "-c", script, modelPath, text)
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse JSON array
|
||||
var tokens []int32
|
||||
output := strings.TrimSpace(out.String())
|
||||
if output == "" || output == "[]" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
// Simple parsing for [1, 2, 3] format
|
||||
output = strings.Trim(output, "[]")
|
||||
if output == "" {
|
||||
return []int32{}, nil
|
||||
}
|
||||
|
||||
for _, s := range strings.Split(output, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
var v int32
|
||||
if _, err := parseIntSimple(s, &v); err == nil {
|
||||
tokens = append(tokens, v)
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func parseIntSimple(s string, v *int32) (bool, error) {
|
||||
var n int64
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int64(c-'0')
|
||||
}
|
||||
}
|
||||
*v = int32(n)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func equalInt32Slice(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
// TestBPEPretokenizer verifies BPE pretokenizer splits text correctly
|
||||
// using the GPT-2 style regex pattern (no dependency on tokenizer files)
|
||||
func TestBPEPretokenizer(t *testing.T) {
|
||||
pattern := `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
|
||||
re := regexp.MustCompile(rewritePatternForRE2(pattern))
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", " world"}},
|
||||
{"Hello, world!", []string{"Hello", ",", " world", "!"}},
|
||||
{"don't", []string{"don", "'t"}},
|
||||
{"I'm", []string{"I", "'m"}},
|
||||
{"123", []string{"123"}},
|
||||
{"12345", []string{"12345"}}, // GPT-2 pattern matches any digit sequence
|
||||
{"a b", []string{"a", " ", " b"}}, // whitespace boundary: last space prepends to word
|
||||
{" ", []string{" "}}, // pure whitespace stays together
|
||||
{"\n\n", []string{"\n\n"}}, // newlines stay together
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
// Get regex matches
|
||||
matches := re.FindAllStringIndex(tt.input, -1)
|
||||
var chunks []string
|
||||
for _, m := range matches {
|
||||
chunks = append(chunks, tt.input[m[0]:m[1]])
|
||||
}
|
||||
|
||||
// Apply whitespace boundary fix (same logic as Encode)
|
||||
for i := 0; i < len(chunks)-1; i++ {
|
||||
if isNonNewlineWhitespace(chunks[i]) && len(chunks[i+1]) > 0 {
|
||||
r, _ := []rune(chunks[i+1])[0], 0
|
||||
if r >= 'A' && r <= 'z' { // simplified letter check
|
||||
// Move last space to next chunk
|
||||
if len(chunks[i]) > 0 {
|
||||
lastSpace := chunks[i][len(chunks[i])-1:]
|
||||
chunks[i] = chunks[i][:len(chunks[i])-1]
|
||||
chunks[i+1] = lastSpace + chunks[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter empty chunks
|
||||
var result []string
|
||||
for _, c := range chunks {
|
||||
if c != "" {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("chunk %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSentencePiecePretokenizer verifies SentencePiece doesn't use pretokenizer
|
||||
// and correctly replaces spaces with ▁ (no dependency on tokenizer files)
|
||||
func TestSentencePiecePretokenizer(t *testing.T) {
|
||||
// SentencePiece has no pretokenizer - whole text is one chunk
|
||||
// Spaces are replaced with ▁ during encoding
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string // after space replacement
|
||||
}{
|
||||
{"Hello", "Hello"},
|
||||
{"Hello world", "Hello▁world"},
|
||||
{"Hello, world!", "Hello,▁world!"},
|
||||
{" spaces ", "▁▁▁spaces▁▁▁"},
|
||||
{" Hello", "▁Hello"},
|
||||
{"Hello ", "Hello▁"},
|
||||
{"a b c", "a▁b▁c"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
// SentencePiece encoding: replace space with ▁
|
||||
result := strings.ReplaceAll(tt.input, " ", "▁")
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWordPiecePretokenizer verifies WordPiece (BERT) pretokenizer splits correctly
|
||||
// BertPreTokenizer splits on whitespace and punctuation
|
||||
func TestWordPiecePretokenizer(t *testing.T) {
|
||||
// BertPreTokenizer behavior: split on whitespace and punctuation
|
||||
// Whitespace is stripped, punctuation becomes separate tokens
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"Hello", []string{"Hello"}},
|
||||
{"Hello world", []string{"Hello", "world"}}, // whitespace stripped
|
||||
{"Hello, world!", []string{"Hello", ",", "world", "!"}}, // punct separate
|
||||
{"don't", []string{"don", "'", "t"}}, // apostrophe separate (unlike BPE)
|
||||
{" spaces ", []string{"spaces"}}, // whitespace stripped
|
||||
{"Hello.World", []string{"Hello", ".", "World"}}, // punct splits
|
||||
{"test@email.com", []string{"test", "@", "email", ".", "com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := splitBertStyle(tt.input)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("got %v, want %v", result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i := range result {
|
||||
if result[i] != tt.expected[i] {
|
||||
t.Errorf("token %d: got %q, want %q", i, result[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// splitBertStyle mimics BertPreTokenizer: split on whitespace and punctuation
|
||||
func splitBertStyle(s string) []string {
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
|
||||
for _, r := range s {
|
||||
if r == ' ' || r == '\t' || r == '\n' || r == '\r' {
|
||||
// Whitespace: flush current token, don't add whitespace
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
} else if isPunct(r) {
|
||||
// Punctuation: flush current, add punct as separate token
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
result = append(result, string(r))
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
if current.Len() > 0 {
|
||||
result = append(result, current.String())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isPunct(r rune) bool {
|
||||
// Common ASCII punctuation
|
||||
return (r >= '!' && r <= '/') || (r >= ':' && r <= '@') ||
|
||||
(r >= '[' && r <= '`') || (r >= '{' && r <= '~')
|
||||
}
|
||||
|
||||
// TestRepeatedDigits verifies correct tokenization of repeated digit sequences.
|
||||
// Llama-style tokenizers split digits in groups of 1-3 due to the \p{N}{1,3} pattern.
|
||||
func TestRepeatedDigits(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
// Pattern: 1 digit, 2 digits, 3 digits, then repeats
|
||||
// "0" -> [single], "00" -> [double], "000" -> [triple]
|
||||
// "0000" -> [triple, single], etc.
|
||||
tests := []struct {
|
||||
input string
|
||||
count int // expected token count
|
||||
}{
|
||||
{"0", 1},
|
||||
{"00", 1},
|
||||
{"000", 1},
|
||||
{"0000", 2}, // 3 + 1
|
||||
{"00000", 2}, // 3 + 2
|
||||
{"000000", 2}, // 3 + 3
|
||||
{"0000000", 3},
|
||||
{"00000000", 3},
|
||||
{"000000000", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
ids := tok.Encode(tt.input, false)
|
||||
if len(ids) != tt.count {
|
||||
t.Errorf("Encode(%q) = %d tokens, want %d", tt.input, len(ids), tt.count)
|
||||
}
|
||||
// Verify roundtrip
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != tt.input {
|
||||
t.Errorf("Decode(Encode(%q)) = %q", tt.input, decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNullByte verifies that null bytes roundtrip correctly
|
||||
func TestNullByte(t *testing.T) {
|
||||
tok, err := Load("./testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Skipf("mini_llama.json not available: %v", err)
|
||||
}
|
||||
|
||||
ids := tok.Encode("\x00", false)
|
||||
decoded := tok.Decode(ids)
|
||||
if decoded != "\x00" {
|
||||
t.Errorf("null byte roundtrip failed: got %q, want %q", decoded, "\x00")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenizerTypeDetection verifies correct detection of tokenizer types
|
||||
func TestTokenizerTypeDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decoder string
|
||||
expected TokenizerType
|
||||
}{
|
||||
{
|
||||
name: "ByteLevel decoder (BPE)",
|
||||
decoder: `{"type": "ByteLevel"}`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
{
|
||||
name: "Sequence with Replace ▁ (SentencePiece)",
|
||||
decoder: `{
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "}
|
||||
]
|
||||
}`,
|
||||
expected: TokenizerSentencePiece,
|
||||
},
|
||||
{
|
||||
name: "null decoder (BPE default)",
|
||||
decoder: `null`,
|
||||
expected: TokenizerBPE,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isSPM := detectSentencePiece([]byte(tt.decoder))
|
||||
var got TokenizerType
|
||||
if isSPM {
|
||||
got = TokenizerSentencePiece
|
||||
} else {
|
||||
got = TokenizerBPE
|
||||
}
|
||||
if got != tt.expected {
|
||||
t.Errorf("got %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenDefault verifies PAD() returns -1 when not configured
|
||||
func TestPADTokenDefault(t *testing.T) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
// mini_llama.json has no PAD token configured, should return -1
|
||||
if got := tok.PAD(); got != -1 {
|
||||
t.Errorf("PAD() = %d, want -1 (not configured)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromConfig verifies PAD token is loaded from tokenizer_config.json
|
||||
func TestPADTokenFromConfig(t *testing.T) {
|
||||
// Create temp directory with tokenizer files
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<|endoftext|>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<|endoftext|>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token
|
||||
configJSON := `{
|
||||
"pad_token": "<|endoftext|>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<|endoftext|>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenFromSpecialTokensMap verifies PAD falls back to special_tokens_map.json
|
||||
func TestPADTokenFromSpecialTokensMap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"<pad>": 0, "hello": 1, "world": 2},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "<pad>", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write special_tokens_map.json with pad_token
|
||||
mapJSON := `{
|
||||
"pad_token": "<pad>"
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "special_tokens_map.json"), []byte(mapJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write special_tokens_map.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 (<pad>)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPADTokenWithContentObject verifies PAD token works with {"content": "..."} format
|
||||
func TestPADTokenWithContentObject(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write minimal tokenizer.json
|
||||
tokenizerJSON := `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"[PAD]": 0, "hello": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 0, "content": "[PAD]", "special": true}
|
||||
]
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.json"), []byte(tokenizerJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
// Write tokenizer_config.json with pad_token as object (HuggingFace format)
|
||||
configJSON := `{
|
||||
"pad_token": {"content": "[PAD]", "lstrip": false, "normalized": false}
|
||||
}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer_config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write tokenizer_config.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := Load(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
if got := tok.PAD(); got != 0 {
|
||||
t.Errorf("PAD() = %d, want 0 ([PAD])", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks
|
||||
|
||||
func BenchmarkEncode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{"short", "Hello, world!"},
|
||||
{"medium", "The quick brown fox jumps over the lazy dog. " + strings.Repeat("This is a test. ", 10)},
|
||||
{"long", strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecode(b *testing.B) {
|
||||
tok, err := Load("testdata/mini_llama.json")
|
||||
if err != nil {
|
||||
b.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
text := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 100)
|
||||
tokens := tok.Encode(text, false)
|
||||
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok.Decode(tokens)
|
||||
}
|
||||
}
|
||||
77
x/kvcache/cache.go
Normal file
77
x/kvcache/cache.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/x/ml"
|
||||
"github.com/ollama/ollama/x/model/input"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||
ErrNotSupported = errors.New("model does not support operation")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
// ** used by model implementations **
|
||||
|
||||
// SetLayer sets the active layer of the cache
|
||||
SetLayer(layer int)
|
||||
|
||||
// Get returns the history of key and value tensors plus a mask
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
|
||||
// Put stores a batch of key and value in the cache
|
||||
//
|
||||
// The shape of the tensors is documented in the specific
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
// If an error occurs, the entire context for the sequence should be
|
||||
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||
Remove(seq int, beginIndex, endIndex int32) error
|
||||
}
|
||||
797
x/kvcache/causal.go
Normal file
797
x/kvcache/causal.go
Normal file
@@ -0,0 +1,797 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "errors"
|
||||
// "fmt"
|
||||
// "log/slog"
|
||||
// "math"
|
||||
// "slices"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
|
||||
// // Causal cache stores K and V tensors according to their position in the
|
||||
// // sequence. Returns the history and a mask for attending to past tokens
|
||||
// //
|
||||
// // The tensors are of shape embed dim, kv heads, batch size
|
||||
// // The mask is of shape history size, batch size
|
||||
// type Causal struct {
|
||||
// DType ml.DType
|
||||
|
||||
// // swaWindowSize is the number of tokens that will be included in the mask
|
||||
// // during attention operations. swaMemorySize is the number of tokens that
|
||||
// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||
// // for unlimited or if sliding window attention is not being used.
|
||||
// swaWindowSize int32
|
||||
// swaMemorySize int32
|
||||
|
||||
// chunkSize int32
|
||||
|
||||
// opts CausalOptions
|
||||
|
||||
// // maxBatch is the largest batch that we might receive
|
||||
// maxBatch int
|
||||
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // size of the current batch
|
||||
// curBatchSize int
|
||||
|
||||
// // locations for data storage for this batch
|
||||
// curLoc ml.Tensor
|
||||
|
||||
// // mask of the cache as used by this batch
|
||||
// curMask ml.Tensor
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // locations in the cache that are needed for this batch
|
||||
// curCellRange cellRange
|
||||
|
||||
// // curSequences is the sequences corresponding to this pass's entries in the cache
|
||||
// curSequences []int
|
||||
|
||||
// // curPositions is the positions corresponding to this pass's entries in the cache
|
||||
// curPositions []int32
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // for each possible location in the cache, stores the position and set of sequences
|
||||
// // that reference the data there
|
||||
// cells []cacheCell
|
||||
|
||||
// // maps from sequence to the range of locations where it is stored in the cache
|
||||
// cellRanges map[int]cellRange
|
||||
|
||||
// // ** cache data storage **
|
||||
|
||||
// shiftFn shiftFn
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
|
||||
// kHeadDims, vHeadDims, numKVHeads map[int]int
|
||||
// }
|
||||
|
||||
// type cacheCell struct {
|
||||
// pos int32
|
||||
// sequences []int
|
||||
// }
|
||||
|
||||
// type cellRange struct {
|
||||
// min int
|
||||
// max int
|
||||
// }
|
||||
|
||||
// func NewCausalCache(shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// swaMemorySize: memorySize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// chunkSize: chunkSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding == 0 {
|
||||
// c.config.CachePadding = 1
|
||||
// }
|
||||
|
||||
// if c.config.MaskBatchPadding == 0 {
|
||||
// c.config.MaskBatchPadding = 1
|
||||
// }
|
||||
|
||||
// // TODO what types do we handle here?
|
||||
// // if c.config.MaskDType == ml.DTypeOther {
|
||||
// // c.config.MaskDType = ml.DTypeFloat32
|
||||
// // }
|
||||
|
||||
// if c.swaWindowSize == 0 {
|
||||
// c.swaWindowSize = math.MaxInt32
|
||||
// }
|
||||
// if c.swaMemorySize == 0 {
|
||||
// c.swaMemorySize = c.swaWindowSize
|
||||
// }
|
||||
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// // causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// // because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// // only have a single sequence.
|
||||
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
// }
|
||||
// if int(c.swaMemorySize) >= capacity {
|
||||
// c.swaMemorySize = math.MaxInt32
|
||||
// }
|
||||
|
||||
// if c.swaMemorySize < c.swaWindowSize {
|
||||
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
// }
|
||||
|
||||
// var cacheSize int
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// cacheSize = maxSequences * capacity
|
||||
// } else {
|
||||
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
// }
|
||||
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
// c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
// c.DType = dtype
|
||||
// c.cellRanges = make(map[int]cellRange)
|
||||
// c.backend = backend
|
||||
// c.maxBatch = maxBatch
|
||||
// }
|
||||
|
||||
// func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
|
||||
// // panic("XXX Causal.StartForward")
|
||||
// c.curBatchSize = len(batch.Positions)
|
||||
// c.curSequences = batch.Sequences
|
||||
// c.curPositions = batch.Positions
|
||||
// c.opts.Except = nil
|
||||
|
||||
// var locs []int32
|
||||
// if !reserve {
|
||||
// c.updateSlidingWindow()
|
||||
|
||||
// var err error
|
||||
// locs, err = c.findLocs()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
|
||||
|
||||
// for i, pos := range batch.Positions {
|
||||
// seq := batch.Sequences[i]
|
||||
// loc := int(locs[i])
|
||||
|
||||
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// seqRange = newRange()
|
||||
// }
|
||||
|
||||
// seqRange.min = min(seqRange.min, loc)
|
||||
// c.curCellRange.min = min(c.curCellRange.min, loc)
|
||||
|
||||
// seqRange.max = max(seqRange.max, loc)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, loc)
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
// }
|
||||
// } else {
|
||||
// // If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// // to the worst case.
|
||||
// locs = make([]int32, c.curBatchSize)
|
||||
// for i := range locs {
|
||||
// locs[i] = int32(i)
|
||||
// }
|
||||
// c.curCellRange.min = 0
|
||||
// c.curCellRange.max = len(c.cells) - 1
|
||||
// }
|
||||
|
||||
// // XXX Building up the locs for what's already processed (if any)
|
||||
// dummyLocs := []int{}
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// } else {
|
||||
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
|
||||
// dummyLocs = append(dummyLocs, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
|
||||
|
||||
// slog.Info("XXX Causal.StartForward", "locs", locs)
|
||||
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func newRange() cellRange {
|
||||
// return cellRange{
|
||||
// min: math.MaxInt,
|
||||
// max: 0,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Returns a slice of locations where each token in the batch should be stored
|
||||
// func (c *Causal) findLocs() ([]int32, error) {
|
||||
// loc := make([]int32, 0, c.curBatchSize)
|
||||
|
||||
// for i := range c.cells {
|
||||
// if len(c.cells[i].sequences) == 0 {
|
||||
// loc = append(loc, int32(i))
|
||||
// if len(loc) >= c.curBatchSize {
|
||||
// return loc, nil
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
// }
|
||||
|
||||
// func (c *Causal) updateSlidingWindow() {
|
||||
// c.curCellRange = newRange()
|
||||
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// for _, seq := range c.curSequences {
|
||||
// if seqRange, ok := c.cellRanges[seq]; ok {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||
// }
|
||||
// }
|
||||
|
||||
// return
|
||||
// }
|
||||
|
||||
// type lowestPosition struct {
|
||||
// pos int32
|
||||
// curBatch bool
|
||||
// }
|
||||
|
||||
// // create a map of unique sequences to the lowest position in that sequence
|
||||
// lowestPos := make(map[int]lowestPosition)
|
||||
// for i := range c.curPositions {
|
||||
// seq := c.curSequences[i]
|
||||
|
||||
// lowest, ok := lowestPos[seq]
|
||||
// if !ok {
|
||||
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
// } else if c.curPositions[i] < lowest.pos {
|
||||
// lowest.pos = c.curPositions[i]
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowest
|
||||
// }
|
||||
|
||||
// // for any sequences are not part of this batch, clean up any tokens
|
||||
// // that are no longer needed after the processing of the previous
|
||||
// // batch
|
||||
// for seq, seqRange := range c.cellRanges {
|
||||
// if _, ok := lowestPos[seq]; !ok {
|
||||
// var last int32
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
// }
|
||||
// }
|
||||
|
||||
// // delete any entries that are beyond the window of the oldest position in the sequence
|
||||
// for seq, lowest := range lowestPos {
|
||||
// oldRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// newRange := newRange()
|
||||
|
||||
// for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// newRange.min = min(newRange.min, i)
|
||||
// newRange.max = max(newRange.max, i)
|
||||
// }
|
||||
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = newRange
|
||||
// }
|
||||
// }
|
||||
|
||||
// func roundDown(length, pad int) int {
|
||||
// return (length / pad) * pad
|
||||
// }
|
||||
|
||||
// func roundUp(length, pad int) int {
|
||||
// return ((length + pad - 1) / pad) * pad
|
||||
// }
|
||||
|
||||
// // Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// // token in the history should apply. This is based on both the sequence and causality (the
|
||||
// // position of the history is not ahead of the token in the batch).
|
||||
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// // Align and pad the two dimensions as required by the backend
|
||||
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
// mask := make([]float32, batchSize*length)
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// // has already been masked out because the sequence doesn't match.
|
||||
// for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
// mask[i] = float32(math.Inf(-1))
|
||||
// }
|
||||
|
||||
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
|
||||
|
||||
// // if c.config.MaskDType != ml.DTypeFloat32 {
|
||||
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
// // }
|
||||
|
||||
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
|
||||
|
||||
// return maskTensor
|
||||
// }
|
||||
|
||||
// func (c *Causal) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// type CausalOptions struct {
|
||||
// // Enabled controls whether the causal mask is generated for a particular index in a batch
|
||||
// Except []int
|
||||
// }
|
||||
|
||||
// // SetCausal disables causal mask generation for a particular range of indicies in
|
||||
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
// if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
// c.opts = opts
|
||||
// if ctx != nil {
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// key := c.keys[c.curLayer]
|
||||
// value := c.values[c.curLayer]
|
||||
|
||||
// kHeadDim := c.kHeadDims[c.curLayer]
|
||||
// vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// // rowSize := numKVHeads * c.curBatchSize
|
||||
// // cachedSize := c.curMask.Dim(1)
|
||||
// cachedSize := c.curLoc.Dim(0)
|
||||
// // kCellSize := kHeadDim * numKVHeads
|
||||
// // vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// slog.Info("XXX Causal.Get full cache", "key", key)
|
||||
// slog.Info("XXX Causal.Get full cache", "value", value)
|
||||
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
|
||||
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
|
||||
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
|
||||
// // panic("XXX")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // panic("full cache value")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not converted
|
||||
// // vHeadDim := value.Dim(1)
|
||||
// // elemSize := value.Stride(2)
|
||||
|
||||
// // value = value.AsStrided(ctx,
|
||||
// // []int{numKVHeads, vHeadDim, cachedSize},
|
||||
// // []int{value.Stride(0), value.Stride(1)},
|
||||
// // elemSize*c.curCellRange.min,
|
||||
// // )
|
||||
// // } else {
|
||||
// // vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// // rowSize := value.Stride(2)
|
||||
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
|
||||
// // panic("XXX")
|
||||
|
||||
// // }
|
||||
|
||||
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
|
||||
// // // the 1 becomes trailing and messes up later operations
|
||||
// // // This isn't the right solution, but works around it...
|
||||
// // if c.curMask.Dim(1) == 1 {
|
||||
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // fmt.Fprintln(os.Stderr, value.ToString())
|
||||
// // panic("XXX")
|
||||
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
|
||||
|
||||
// return key, value, c.curMask
|
||||
// }
|
||||
|
||||
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// kHeadDim := key.Dim(3)
|
||||
// vHeadDim := value.Dim(3)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// batchSize := key.Dim(2)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// if c.curBatchSize != batchSize {
|
||||
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
// }
|
||||
|
||||
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
|
||||
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
|
||||
// c.kHeadDims[c.curLayer] = kHeadDim
|
||||
// c.vHeadDims[c.curLayer] = vHeadDim
|
||||
// c.numKVHeads[c.curLayer] = numKVHeads
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// // if c.config.PermutedV {
|
||||
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
|
||||
// // } else {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
|
||||
// // }
|
||||
// }
|
||||
|
||||
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
|
||||
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
|
||||
// // panic("XXX")
|
||||
// // curLoc := 0 // TODO c.curLoc is now a tensor
|
||||
// // kSize := numKVHeads * kHeadDim
|
||||
// // vSize := numKVHeads * vHeadDim
|
||||
// // start := []int{int(curLoc), 0}
|
||||
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
|
||||
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
|
||||
// // strides := []int{1, 1}
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
|
||||
|
||||
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
|
||||
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
|
||||
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
|
||||
// // panic("input value")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not adjusted
|
||||
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
// // value = value.Transpose(ctx, 2, 0, 1, 3)
|
||||
|
||||
// // valueCache := c.values[c.curLayer]
|
||||
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
|
||||
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
|
||||
// // } else {
|
||||
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
|
||||
|
||||
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
|
||||
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// }
|
||||
|
||||
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
|
||||
// if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||
// }
|
||||
|
||||
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[dstSeq] = seqRange
|
||||
// }
|
||||
|
||||
// func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// // for sliding window, check that the window of the new sequence is contained in
|
||||
// // the window of what we are storing
|
||||
// var first int32 = math.MaxInt32
|
||||
// var last int32 = -1
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// first = min(first, c.cells[i].pos)
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if last == -1 {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
// return posWindowStart >= first && pos <= last+1
|
||||
// }
|
||||
|
||||
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
// if c.shiftFn == nil {
|
||||
// return ErrNotSupported
|
||||
// }
|
||||
|
||||
// seqRange := c.cellRanges[seq]
|
||||
|
||||
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
// size := min(seqRange.max-start+1, c.maxBatch)
|
||||
// offsets := make([]int32, size)
|
||||
|
||||
// var batchFirst, batchLast int
|
||||
|
||||
// batchFirst = -1
|
||||
// for i := range offsets {
|
||||
// cell := c.cells[start+i]
|
||||
|
||||
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
// offsets[i] = offset
|
||||
// if batchFirst < 0 {
|
||||
// batchFirst = i
|
||||
// }
|
||||
// batchLast = i
|
||||
// }
|
||||
// }
|
||||
|
||||
// if batchFirst < 0 {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// offsets = offsets[batchFirst : batchLast+1]
|
||||
|
||||
// slog.Info("XXX Causal.shift creating new temporary context")
|
||||
// ctx := c.backend.NewContext()
|
||||
// kShift := ctx.Input().FromInts(offsets, len(offsets))
|
||||
|
||||
// for i, key := range c.keys {
|
||||
// if key == nil {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// kHeadDim := key.Dim(2)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// rowSize := key.Stride(0)
|
||||
|
||||
// key = key.AsStrided(ctx,
|
||||
// []int{len(offsets), numKVHeads, kHeadDim},
|
||||
// []int{key.Stride(0), key.Stride(1)},
|
||||
// rowSize*(start+batchFirst),
|
||||
// )
|
||||
|
||||
// roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
// if err != nil {
|
||||
// ctx.Close()
|
||||
// return err
|
||||
// }
|
||||
|
||||
// ctx.Forward(roped.Copy(ctx, key))
|
||||
// }
|
||||
|
||||
// ctx.Compute()
|
||||
// ctx.Close()
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// // should return an error, which will trigger the runner to evaluate the full history and
|
||||
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// // results in use after free, so we don't do it for now.
|
||||
|
||||
// var offset int32
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// offset = beginIndex - endIndex
|
||||
// }
|
||||
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// if c.cells[i].pos >= endIndex {
|
||||
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
// }
|
||||
|
||||
// c.cells[i].pos += offset
|
||||
// }
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if seqRange == newRange() {
|
||||
// delete(c.cellRanges, seq)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// err := c.shift(seq, endIndex+offset, offset)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
973
x/kvcache/causal_test.go
Normal file
973
x/kvcache/causal_test.go
Normal file
@@ -0,0 +1,973 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
// "math"
|
||||
// "slices"
|
||||
// "testing"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type testCase struct {
|
||||
// name string
|
||||
// in []float32
|
||||
// inShape []int
|
||||
// seqs []int
|
||||
// pos []int32
|
||||
// expected []float32
|
||||
// expectedShape []int
|
||||
// expectedMask []float32
|
||||
// }
|
||||
|
||||
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
||||
// t.Helper()
|
||||
// for _, permuted := range []bool{false, true} {
|
||||
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
||||
// fn(t, &testBackend{permutedV: permuted})
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestStore(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// inShape: []int{2, 3, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// expectedShape: []int{2, 3, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{115, 215, 125, 225, 135, 235},
|
||||
// inShape: []int{2, 3, 1},
|
||||
// seqs: []int{0},
|
||||
// pos: []int32{4},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||
// expectedShape: []int{2, 3, 5},
|
||||
// expectedMask: []float32{0, 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWA(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 6, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWASeparateBatches(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "First seq 0",
|
||||
// in: []float32{1, 2},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{1, 2},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 0",
|
||||
// in: []float32{3, 4},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 3},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x,
|
||||
// x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "First seq 1",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{5, 6},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 1",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{6, 3, 4, 7, 8},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// x, x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Third seq 0",
|
||||
// in: []float32{9, 10},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{9, 10, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWAMemCache(1, 3, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 2, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestChunkedAttention(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewChunkedAttentionCache(2, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// testCache(
|
||||
// t, backend, cache,
|
||||
// []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6, 7},
|
||||
// inShape: []int{1, 1, 3},
|
||||
// seqs: []int{0, 0, 0},
|
||||
// pos: []int32{4, 5, 6},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||
// expectedShape: []int{1, 1, 7},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, 0, x, x,
|
||||
// x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "ThirdBatch",
|
||||
// in: []float32{8, 9},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{7, 8},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
// expectedShape: []int{1, 1, 9},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// )
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSequences(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{2, 2},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestRemove(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// return key.Add(ctx, shift), nil
|
||||
// })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err := cache.Remove(0, 1, math.MaxInt32)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveEnd",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{1, 5, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x,
|
||||
// x, x, 0, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err = cache.Remove(0, 0, 1)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveMiddle",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{7, 4, 3, 4, 6, 8},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x, x,
|
||||
// 0, 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCopy(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// cache.CopyPrefix(0, 1, 2)
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "Copy",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{3, 4},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||
// for _, test := range tests {
|
||||
// t.Run(test.name, func(t *testing.T) {
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats(test.in, test.inShape...)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// out, _, mask := cache.Get(context)
|
||||
|
||||
// context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
// if !slices.Equal(out.Floats(), test.expected) {
|
||||
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestCanResume(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// cache := NewSWACache(windowSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4},
|
||||
// Sequences: []int{0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // with window size 4, nothing has slid out of the window yet
|
||||
// if !cache.CanResume(0, 0) {
|
||||
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 1) {
|
||||
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 2) {
|
||||
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 3) {
|
||||
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 4) {
|
||||
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
// }
|
||||
|
||||
// // shift window by adding position 5
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{5},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCanResumeSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// memSize := int32(5)
|
||||
// cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // shift window by adding position 7
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{7},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 6) {
|
||||
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 7) {
|
||||
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// type testBackend struct {
|
||||
// ml.Backend
|
||||
// permutedV bool
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContext() ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) CacheConfig() ml.CacheConfig {
|
||||
// return ml.CacheConfig{PermutedV: b.permutedV}
|
||||
// }
|
||||
|
||||
// type testContext struct {
|
||||
// ml.Context
|
||||
// }
|
||||
|
||||
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
// total := 0
|
||||
|
||||
// if len(shape) > 0 {
|
||||
// total = 1
|
||||
// for _, s := range shape {
|
||||
// total *= s
|
||||
// }
|
||||
// }
|
||||
|
||||
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
// }
|
||||
|
||||
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
// return c.Empty(dtype, shape...)
|
||||
// }
|
||||
|
||||
// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
|
||||
// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
// copy(t.data, s)
|
||||
|
||||
// return t
|
||||
// }
|
||||
|
||||
// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
|
||||
// f := make([]float32, len(s))
|
||||
// for i := range f {
|
||||
// f[i] = float32(s[i])
|
||||
// }
|
||||
|
||||
// out := c.FromFloats(f, shape...)
|
||||
// out.(*testTensor).dtype = ml.DTypeI32
|
||||
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
// s := make([]float32, 0, int((stop-start)/step))
|
||||
// for i := start; i < stop; i += step {
|
||||
// s = append(s, i)
|
||||
// }
|
||||
|
||||
// out := c.FromFloats(s, len(s))
|
||||
// out.(*testTensor).dtype = dtype
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (c *testContext) Input() ml.Context { return c }
|
||||
// func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
// func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
// func (c *testContext) Reserve() {}
|
||||
|
||||
// func (c *testContext) MaxGraphNodes() int {
|
||||
// return 10
|
||||
// }
|
||||
|
||||
// func (c *testContext) Close() {}
|
||||
|
||||
// type testTensor struct {
|
||||
// ml.Tensor
|
||||
|
||||
// dtype ml.DType
|
||||
// elementSize int
|
||||
// data []float32
|
||||
// shape []int
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Dim(n int) int {
|
||||
// return t.shape[n]
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Stride(n int) int {
|
||||
// stride := t.elementSize
|
||||
// for i := range n {
|
||||
// stride *= t.shape[i]
|
||||
// }
|
||||
|
||||
// return stride
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Shape() []int {
|
||||
// return t.shape
|
||||
// }
|
||||
|
||||
// func (t *testTensor) DType() ml.DType {
|
||||
// return t.dtype
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Floats() []float32 {
|
||||
// out := make([]float32, len(t.data))
|
||||
// copy(out, t.data)
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
// for i := range out.data {
|
||||
// out.data[i] = -t.data[i]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
// for i := range out.data {
|
||||
// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||
// }
|
||||
|
||||
// return out
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
// return &testTensor{
|
||||
// dtype: t.dtype,
|
||||
// elementSize: t.elementSize,
|
||||
// data: t.data,
|
||||
// shape: shape,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
// offset /= t.elementSize
|
||||
|
||||
// var s []int
|
||||
|
||||
// switch len(shape) {
|
||||
// case 1:
|
||||
// s = []int{shape[0]}
|
||||
// case 3:
|
||||
// s = []int{shape[0], shape[2]}
|
||||
// case 5:
|
||||
// s = []int{shape[0], shape[2], shape[4]}
|
||||
// default:
|
||||
// panic("unsupported number of dimensions")
|
||||
// }
|
||||
|
||||
// context := &testContext{}
|
||||
|
||||
// view := context.Empty(t.dtype, s...).(*testTensor)
|
||||
// view.data = t.data[offset : offset+len(view.data)]
|
||||
|
||||
// return view
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||
// if len(t.shape) > 4 || len(order) > 4 {
|
||||
// panic("permute only supports up to 4 dimensions")
|
||||
// }
|
||||
|
||||
// if len(order) != len(t.shape) && len(order) != 4 {
|
||||
// panic("invalid number of dimensions for permute")
|
||||
// }
|
||||
|
||||
// // ggml_permute expects 4 axes, so fill in any missing dimensions.
|
||||
// orderFull := append(make([]int, 0, 4), order...)
|
||||
// for len(orderFull) < 4 {
|
||||
// orderFull = append(orderFull, len(orderFull))
|
||||
// }
|
||||
|
||||
// seen := [4]bool{}
|
||||
|
||||
// shape4 := [4]int{1, 1, 1, 1}
|
||||
// for i := 0; i < len(t.shape) && i < 4; i++ {
|
||||
// shape4[i] = t.shape[i]
|
||||
// }
|
||||
|
||||
// newShape4 := [4]int{1, 1, 1, 1}
|
||||
// for axis := range 4 {
|
||||
// dst := orderFull[axis]
|
||||
// if dst < 0 || dst >= 4 {
|
||||
// panic("invalid axis for permute")
|
||||
// }
|
||||
// if seen[dst] {
|
||||
// panic("duplicate axis for permute")
|
||||
// }
|
||||
// seen[dst] = true
|
||||
// newShape4[dst] = shape4[axis]
|
||||
// }
|
||||
|
||||
// total := len(t.data)
|
||||
// newData := make([]float32, total)
|
||||
|
||||
// if total > 0 {
|
||||
// oldDims := shape4
|
||||
// newDims := newShape4
|
||||
|
||||
// oldStride := [4]int{1, 1, 1, 1}
|
||||
// newStride := [4]int{1, 1, 1, 1}
|
||||
// for i := 1; i < 4; i++ {
|
||||
// oldStride[i] = oldStride[i-1] * oldDims[i-1]
|
||||
// newStride[i] = newStride[i-1] * newDims[i-1]
|
||||
// }
|
||||
|
||||
// var coords [4]int
|
||||
// var newCoords [4]int
|
||||
|
||||
// for idx := range total {
|
||||
// remainder := idx
|
||||
// for axis := range 4 {
|
||||
// dim := oldDims[axis]
|
||||
// if dim == 0 {
|
||||
// coords[axis] = 0
|
||||
// continue
|
||||
// }
|
||||
// coords[axis] = remainder % dim
|
||||
// remainder /= dim
|
||||
// }
|
||||
|
||||
// for axis := range 4 {
|
||||
// newCoords[orderFull[axis]] = coords[axis]
|
||||
// }
|
||||
|
||||
// newIndex := 0
|
||||
// for axis := range 4 {
|
||||
// if newDims[axis] == 0 {
|
||||
// continue
|
||||
// }
|
||||
// newIndex += newCoords[axis] * newStride[axis]
|
||||
// }
|
||||
|
||||
// newData[newIndex] = t.data[idx]
|
||||
// }
|
||||
// }
|
||||
|
||||
// numDims := 4
|
||||
// for numDims > 1 && newShape4[numDims-1] <= 1 {
|
||||
// numDims--
|
||||
// }
|
||||
|
||||
// newShape := make([]int, numDims)
|
||||
// copy(newShape, newShape4[:numDims])
|
||||
|
||||
// return &testTensor{
|
||||
// dtype: t.dtype,
|
||||
// elementSize: t.elementSize,
|
||||
// data: newData,
|
||||
// shape: newShape,
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
|
||||
// dst := t
|
||||
// srcTensor := src.(*testTensor)
|
||||
// idxTensor := idxs.(*testTensor)
|
||||
|
||||
// shapeTo4D := func(shape []int) [4]int {
|
||||
// out := [4]int{1, 1, 1, 1}
|
||||
// for i := 0; i < len(shape) && i < 4; i++ {
|
||||
// out[i] = shape[i]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// computeStrides := func(shape [4]int) [4]int {
|
||||
// out := [4]int{1, 1, 1, 1}
|
||||
// for i := 1; i < 4; i++ {
|
||||
// out[i] = out[i-1] * shape[i-1]
|
||||
// }
|
||||
// return out
|
||||
// }
|
||||
|
||||
// dstShape4D := shapeTo4D(dst.shape)
|
||||
// srcShape4D := shapeTo4D(srcTensor.shape)
|
||||
// idxShape4D := shapeTo4D(idxTensor.shape)
|
||||
|
||||
// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
|
||||
// panic("SetRows requires matching tensor shapes")
|
||||
// }
|
||||
|
||||
// if srcShape4D[1] != idxShape4D[0] {
|
||||
// panic("SetRows rows/index mismatch")
|
||||
// }
|
||||
|
||||
// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
|
||||
// panic("SetRows cannot broadcast indices")
|
||||
// }
|
||||
|
||||
// if idxShape4D[3] != 1 {
|
||||
// panic("SetRows expects 1D or 2D index tensors")
|
||||
// }
|
||||
|
||||
// dstStride := computeStrides(dstShape4D)
|
||||
// srcStride := computeStrides(srcShape4D)
|
||||
// idxStride := computeStrides(idxShape4D)
|
||||
|
||||
// numColumns := srcShape4D[0]
|
||||
// numRows := srcShape4D[1]
|
||||
|
||||
// for dim3Index := range dstShape4D[3] {
|
||||
// for dim2Index := range dstShape4D[2] {
|
||||
// idxDim2 := 0
|
||||
// idxDim3 := 0
|
||||
// if idxShape4D[1] > 0 {
|
||||
// idxDim2 = dim2Index % idxShape4D[1]
|
||||
// }
|
||||
// if idxShape4D[2] > 0 {
|
||||
// idxDim3 = dim3Index % idxShape4D[2]
|
||||
// }
|
||||
|
||||
// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
|
||||
// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
|
||||
// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
|
||||
|
||||
// for row := range numRows {
|
||||
// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
|
||||
// if idx < 0 || idx >= dstShape4D[1] {
|
||||
// panic("SetRows index out of range")
|
||||
// }
|
||||
|
||||
// srcOffset := srcBase + row*srcStride[1]
|
||||
// dstOffset := dstBase + idx*dstStride[1]
|
||||
|
||||
// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return dst
|
||||
// }
|
||||
|
||||
// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
// copy(t2.(*testTensor).data, t.data)
|
||||
// return nil
|
||||
// }
|
||||
156
x/kvcache/encoder.go
Normal file
156
x/kvcache/encoder.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// // Encoder cache stores K and V tensors that are position independent
|
||||
// //
|
||||
// // The tensors can be of any shape and will be returned as they were stored
|
||||
// // The mask is currently always nil
|
||||
// //
|
||||
// // Not currently safe for multiple sequences
|
||||
// type EncoderCache struct {
|
||||
// // config controls mostly backend-specific optimizations
|
||||
// config *ml.CacheConfig
|
||||
|
||||
// // ** current forward pass **
|
||||
|
||||
// // the active layer for Get and Put
|
||||
// curLayer int
|
||||
|
||||
// // if something is stored during this pass, this
|
||||
// // will be the position (but there is no guarantee
|
||||
// // anything will be stored)
|
||||
// curPos int32
|
||||
|
||||
// // curReserve indicates that this forward pass is only for
|
||||
// // memory reservation and we should not update our metadata
|
||||
// // based on it.
|
||||
// curReserve bool
|
||||
|
||||
// // ** cache metadata **
|
||||
|
||||
// // was something stored in the cache?
|
||||
// encoderCached bool
|
||||
|
||||
// // position of the cached data
|
||||
// encoderPos int32
|
||||
|
||||
// // ** cache data storage **
|
||||
// backend ml.Backend
|
||||
// ctxs map[int]ml.Context
|
||||
// keys, values map[int]ml.Tensor
|
||||
// }
|
||||
|
||||
// func NewEncoderCache() *EncoderCache {
|
||||
// return &EncoderCache{
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if maxSequences > 1 {
|
||||
// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
// }
|
||||
|
||||
// c.backend = backend
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Close() {
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// // We work with the most recent image
|
||||
// if len(batch.Multimodal) > 0 {
|
||||
// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
// }
|
||||
|
||||
// c.curReserve = reserve
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) EncoderCached() bool {
|
||||
// return c.encoderCached
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// if !c.curReserve {
|
||||
// c.encoderPos = c.curPos
|
||||
// c.encoderCached = true
|
||||
// }
|
||||
|
||||
// if c.config.PermutedV {
|
||||
// value = value.Transpose(ctx, 1, 2, 0, 3)
|
||||
// }
|
||||
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||
// }
|
||||
|
||||
// ctx.Forward(
|
||||
// key.Copy(ctx, c.keys[c.curLayer]),
|
||||
// value.Copy(ctx, c.values[c.curLayer]),
|
||||
// )
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// panic("encoder cache does not support multiple sequences")
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
// c.encoderCached = false
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user