Compare commits

...

19 Commits

Author SHA1 Message Date
jmorganca
bb1a5617b6 readme: add instructions to build with MLX 2026-01-15 09:52:56 -08:00
jmorganca
0d3648c1be glm-image wip 2026-01-14 16:46:50 -08:00
Daniel Hiltgen
02a2401596 mlx: bundle openblas dependency (#13706) 2026-01-13 15:29:47 -08:00
Daniel Hiltgen
e4b488a7b5 CI: dedup cuda libraries to reduce payload size (#13704) 2026-01-13 11:25:31 -08:00
Daniel Hiltgen
98079ddd79 ci: add missing mlx components to release build (#13702) 2026-01-13 09:13:09 -08:00
Jeffrey Morgan
d70942f47b x/imagegen/cli: skip local model check (#13699) 2026-01-12 22:38:10 -08:00
Jeffrey Morgan
58e4701557 scripts: increase notarization timeout to 20m (#13697)
The 100MB mlx.metallib file significantly increased the app bundle size,
causing Apple's notarization service to timeout with the previous 10m limit.
2026-01-12 20:38:38 -08:00
Jeffrey Morgan
dbf47ee55a cmake: use CMAKE_SYSTEM_PROCESSOR instead of CMAKE_OSX_ARCHITECTURES for mlx.metallib install (#13696)
The CMake condition for installing mlx.metallib checks
CMAKE_OSX_ARCHITECTURES, but this variable is only set when explicitly
passed - not auto-detected. The arm64 build was missing this flag,
causing the metallib to not be installed, which then caused codesign
to fail on the unexpanded glob pattern.
2026-01-12 20:05:11 -08:00
Jeffrey Morgan
af7ea6e96e x/imagegen: install mlx.metallib and fix macOS rpath handling, add mlx library directories to LD_LIBRARY_PATH (#13695)
- Install mlx.metallib for arm64 builds (required for Metal GPU acceleration)
- Apply rpath settings to all macOS builds, not just x86_64
- Add CMAKE_BUILD_WITH_INSTALL_RPATH to avoid install_name_tool errors
- Update build_darwin.sh to copy, sign, and package the metallib
2026-01-12 19:03:11 -08:00
Jeffrey Morgan
8f1e0140e7 x/imagegen: fix mlx build in Dockerfile and macOS build script (#13693) 2026-01-12 15:52:43 -08:00
Parth Sareen
35c3c9e3c2 anthropic: allow non-thinking models when using Anthropic API (#13692) 2026-01-12 15:13:26 -08:00
Parth Sareen
d06acbcb19 x/cmd: enable web search and web fetch with flag (#13690) 2026-01-12 13:59:40 -08:00
Jeffrey Morgan
9667c2282f x/imagegen: add naive TeaCache and FP8 quantization support (#13683)
TeaCache:
- Timestep embedding similarity caching for diffusion models
- Polynomial rescaling with configurable thresholds
- Reduces transformer forward passes by ~30-50%

FP8 quantization:
- Support for FP8 quantized models (8-bit weights with scales)
- QuantizedMatmul on Metal, Dequantize on CUDA
- Client-side quantization via ollama create --quantize fp8

Other bug fixes:
- Fix `/api/show` API for image generation models
- Server properly returns model info (architecture, parameters, quantization)
- Memory allocation optimizations
- CLI improvements for image generation
2026-01-12 13:45:22 -08:00
Jeffrey Morgan
a937a68317 server: fix slow 'ollama rm' of models with many layers (#13680)
RemoveLayers was calling Manifests() for each layer to check if it was
shared with other models. For models with many blobs (e.g., tensor
models), this caused O(N*M) manifest reads.

Now loads manifests once and builds a set of in-use digests.
2026-01-12 13:17:48 -08:00
Parth Sareen
2185112d84 x/cmd: connect /set flags to behavior in experimental mode (#13684) 2026-01-12 00:40:44 -08:00
Parth Sareen
91926601dc x: add missing /set, /show, /load, /save commands to experimental mode (#13682) 2026-01-11 23:12:31 -08:00
Jeffrey Morgan
361d6c16c2 x/imagegen/transfer: fix timeout and progress reporting (#13679)
Removes 5-minute HTTP client timeout that caused "context deadline
exceeded" errors on large file downloads. Stall detection (10s)
already handles unresponsive connections.

Fixes progress bar total going down on resume by calculating total
from all blobs upfront and reporting already-downloaded bytes
as completed immediately.
2026-01-11 15:33:53 -08:00
Patrick Devine
7e2496e88e Fix cmake install command in README (#13678)
Update installation command for MLX component in README.
2026-01-11 13:16:42 -08:00
WhatToPutHere
5b84e29882 docs: fix troubleshooting page (#13674)
Updated the link in the log output description to point to the correct troubleshooting guide format.
2026-01-11 00:58:07 -08:00
55 changed files with 6662 additions and 414 deletions

View File

@@ -13,7 +13,7 @@ body:
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.mdx#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false

View File

@@ -372,13 +372,17 @@ jobs:
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
cache-to: type=inline
- name: Deduplicate CUDA libraries
run: |
./scripts/deduplicate_cuda_libs.sh dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
bin/ollama*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/mlx*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -48,9 +48,10 @@ if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
if(APPLE)
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
@@ -189,13 +190,21 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
)
# Install the Metal library for macOS arm64 (must be colocated with the binary)
# Metal backend is only built for arm64, not x86_64
if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
install(FILES ${CMAKE_BINARY_DIR}/_deps/mlx-build/mlx/backend/metal/kernels/mlx.metallib
DESTINATION ${OLLAMA_INSTALL_DIR}
COMPONENT MLX)
endif()
# Manually install cudart and cublas since they might not be picked up as direct dependencies
if(CUDAToolkit_FOUND)
file(GLOB CUDART_LIBS

View File

@@ -161,6 +161,9 @@ ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -182,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -421,7 +453,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -493,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +668,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +677,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -165,7 +165,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 8 * format.MegaByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader

View File

@@ -100,7 +100,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
return imagegenclient.CreateModel(args[0], ".", p)
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -464,14 +465,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// Check if this is a known image generation model (skip Show/Pull)
if imagegen.HasTensorLayers(name) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
@@ -533,9 +526,18 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
return imagegen.RunCLI(cmd, name, opts.Prompt, interactive, opts.KeepAlive)
}
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
yoloMode, _ := cmd.Flags().GetBool("experimental-yolo")
enableWebsearch, _ := cmd.Flags().GetBool("experimental-websearch")
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
@@ -565,7 +567,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
// Use experimental agent loop with tools
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode)
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive, yoloMode, enableWebsearch)
}
return generateInteractive(cmd, opts)
@@ -671,7 +673,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pushing %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -837,11 +843,6 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
}
func ShowHandler(cmd *cobra.Command, args []string) error {
// Check if this is an image generation model
if imagegen.HasTensorLayers(args[0]) {
return imagegen.Show(args[0], os.Stdout)
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -1786,6 +1787,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
runCmd.Flags().Bool("experimental-yolo", false, "Skip all tool approval prompts (use with caution)")
runCmd.Flags().Bool("experimental-websearch", false, "Enable web search tool in experimental mode")
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)

View File

@@ -1547,6 +1547,79 @@ func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
}
}
func TestShowInfoImageGen(t *testing.T) {
var b bytes.Buffer
err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Requires: "0.14.0",
}, false, &b)
if err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization FP8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
}
func TestPushProgressMessage(t *testing.T) {
tests := []struct {
name string
status string
digest string
wantMsg string
}{
{
name: "uses status when provided",
status: "uploading model",
digest: "sha256:abc123456789def",
wantMsg: "uploading model",
},
{
name: "falls back to digest when status empty",
status: "",
digest: "sha256:abc123456789def",
wantMsg: "pushing abc123456789...",
},
{
name: "handles short digest gracefully",
status: "",
digest: "sha256:abc",
wantMsg: "pushing sha256:abc...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := tt.status
if msg == "" {
if len(tt.digest) >= 19 {
msg = fmt.Sprintf("pushing %s...", tt.digest[7:19])
} else {
msg = fmt.Sprintf("pushing %s...", tt.digest)
}
}
if msg != tt.wantMsg {
t.Errorf("got %q, want %q", msg, tt.wantMsg)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}

View File

@@ -1,3 +0,0 @@
# Troubleshooting
For troubleshooting, see [https://docs.ollama.com/troubleshooting](https://docs.ollama.com/troubleshooting)

View File

@@ -118,6 +118,9 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
return
}
// Set think to nil when being used with Anthropic API to connect to tools like claude code
c.Set("relax_thinking", true)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, anthropic.NewError(http.StatusInternalServerError, err.Error()))

View File

@@ -582,3 +582,26 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
})
}
}
func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
gin.SetMode(gin.TestMode)
var flagSet bool
router := gin.New()
router.Use(AnthropicMessagesMiddleware())
router.POST("/v1/messages", func(c *gin.Context) {
_, flagSet = c.Get("relax_thinking")
c.Status(http.StatusOK)
})
body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}`
req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if !flagSet {
t.Error("expected relax_thinking flag to be set in context")
}
}

View File

@@ -73,7 +73,7 @@ _build_darwin() {
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/imagegen ./x/imagegen/cmd/engine
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
done
}
@@ -82,19 +82,19 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/imagegen dist/darwin-*/imagegen
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/imagegen
chmod +x dist/darwin/ollama-mlx
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/imagegen; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
# create a temporary zip for notarization
TEMP=$(mktemp -u).zip
ditto -c -k --keepParent dist/darwin/ollama "$TEMP"
xcrun notarytool submit "$TEMP" --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
xcrun notarytool submit "$TEMP" --wait --timeout 20m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
rm -f "$TEMP"
fi
@@ -154,23 +154,25 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/imagegen dist/darwin-amd64/imagegen dist/darwin-arm64/imagegen
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
cp dist/darwin-*/lib/ollama/*.so dist/darwin-*/lib/ollama/*.dylib dist/Ollama.app/Contents/Resources/
cp dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
# Copy MLX metallib (architecture-independent, just use arm64 version)
cp dist/darwin-arm64/lib/ollama/*.metallib dist/Ollama.app/Contents/Resources/ 2>/dev/null || true
else
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/imagegen dist/Ollama.app/Contents/Resources/imagegen
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/imagegen ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
@@ -178,11 +180,11 @@ _build_macapp() {
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama imagegen *.so *.dylib) | gzip -9vc > dist/ollama-darwin.tgz
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
@@ -206,7 +208,7 @@ _build_macapp() {
rm -f dist/rw*.dmg
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.dmg
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 10m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f notarytool) submit dist/Ollama.dmg --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
$(xcrun -f stapler) staple dist/Ollama.dmg
else
echo "WARNING: Code signing disabled, this bundle will not work for upgrade testing"

View File

@@ -48,53 +48,12 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
.
fi
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
deduplicate_cuda_libs() {
local base_dir="$1"
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
}
# Run deduplication for each platform output directory
if echo $PLATFORM | grep "," > /dev/null ; then
deduplicate_cuda_libs "./dist/linux_amd64"
deduplicate_cuda_libs "./dist/linux_arm64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_amd64"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist/linux_arm64"
elif echo $PLATFORM | grep "amd64\|arm64" > /dev/null ; then
deduplicate_cuda_libs "./dist"
$(dirname $0)/deduplicate_cuda_libs.sh "./dist"
fi
# buildx behavior changes for single vs. multiplatform

View File

@@ -0,0 +1,60 @@
#!/bin/sh
#
# Deduplicate CUDA libraries across mlx_* and cuda_* directories
# This script finds identical .so* files in mlx_cuda_* directories that exist
# in corresponding cuda_* directories and replaces them with symlinks.
#
set -eu
if [ $# -eq 0 ]; then
echo "ERROR: No directory specified" >&2
echo "Usage: $0 <base_directory>" >&2
exit 1
fi
base_dir="$1"
if [ ! -d "${base_dir}" ]; then
echo "ERROR: Directory ${base_dir} does not exist" >&2
exit 1
fi
echo "Deduplicating CUDA libraries in ${base_dir}..."
# Find all mlx_cuda_* directories
for mlx_dir in "${base_dir}"/lib/ollama/mlx_cuda_*; do
[ -d "${mlx_dir}" ] || continue
# Extract CUDA version (e.g., v12, v13)
cuda_version=$(basename "${mlx_dir}" | sed 's/mlx_cuda_//')
cuda_dir="${base_dir}/lib/ollama/cuda_${cuda_version}"
# Skip if corresponding cuda_* directory doesn't exist
[ -d "${cuda_dir}" ] || continue
echo " Checking ${mlx_dir} against ${cuda_dir}..."
# Find all .so* files in mlx directory
find "${mlx_dir}" -type f -name "*.so*" | while read mlx_file; do
filename=$(basename "${mlx_file}")
cuda_file="${cuda_dir}/${filename}"
# Skip if file doesn't exist in cuda directory
[ -f "${cuda_file}" ] || continue
# Compare checksums
mlx_sum=$(sha256sum "${mlx_file}" | awk '{print $1}')
cuda_sum=$(sha256sum "${cuda_file}" | awk '{print $1}')
if [ "${mlx_sum}" = "${cuda_sum}" ]; then
echo " Deduplicating ${filename}"
# Calculate relative path from mlx_dir to cuda_dir
rel_path="../cuda_${cuda_version}/${filename}"
rm -f "${mlx_file}"
ln -s "${rel_path}" "${mlx_file}"
fi
done
done
echo "Deduplication complete"

View File

@@ -47,16 +47,40 @@ func (m *Manifest) Remove() error {
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
ms, err := Manifests(true)
if err != nil {
return err
}
// Build set of digests still in use by other manifests
inUse := make(map[string]struct{})
for _, other := range ms {
for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" {
inUse[layer.Digest] = struct{}{}
}
}
}
// Remove layers not used by any other manifest
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == "" {
continue
}
if _, used := inUse[layer.Digest]; used {
continue
}
blob, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}
return nil
}

View File

@@ -1124,6 +1124,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
if req.System != "" {
m.System = req.System
}
@@ -1206,6 +1215,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -2059,8 +2072,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
// Set think to nil when being used with Anthropic API to connect to tools like claude code
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}

View File

@@ -1,24 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors. To build:
```
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install --component MLX
go build -tags mlx .
```
On linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled.
## Image Generation
Based on the experimental MLX backend, we're working on adding imagegen support. After running the cmake commands above:
```
go build -o imagegen ./x/imagegen/cmd/engine
```

View File

@@ -41,6 +41,7 @@ var optionLabels = []string{
var toolDisplayNames = map[string]string{
"bash": "Bash",
"web_search": "Web Search",
"web_fetch": "Web Fetch",
}
// ToolDisplayName returns the human-readable display name for a tool.
@@ -565,6 +566,16 @@ func formatToolDisplay(toolName string, args map[string]any) string {
}
}
// For web fetch, show URL and internet notice
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", displayName))
sb.WriteString(fmt.Sprintf("URL: %s\n", url))
sb.WriteString("Uses internet via ollama.com")
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", displayName))
if len(args) > 0 {
@@ -1017,6 +1028,16 @@ func FormatApprovalResult(toolName string, args map[string]any, result ApprovalR
}
}
if toolName == "web_fetch" {
if url, ok := args["url"].(string); ok {
// Truncate long URLs
if len(url) > 50 {
url = url[:47] + "..."
}
return fmt.Sprintf("\033[1m%s:\033[0m %s: %s", label, displayName, url)
}
}
return fmt.Sprintf("\033[1m%s:\033[0m %s", label, displayName)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"os"
"os/signal"
"slices"
"strings"
"syscall"
"time"
@@ -24,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -130,6 +139,7 @@ type RunOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
Verbose bool
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
@@ -178,6 +188,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
var consecutiveErrors int // Track consecutive 500 errors for retry limit
var latest api.ChatResponse
role := "assistant"
messages := opts.Messages
@@ -187,6 +198,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
p.StopAndClear()
}
latest = response
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
@@ -483,6 +495,10 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Println()
}
if opts.Verbose {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
@@ -634,7 +650,8 @@ func checkModelCapabilities(ctx context.Context, modelName string) (supportsTool
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
// If yoloMode is true, all tool approvals are skipped.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool) error {
// If enableWebsearch is true, the web search tool is registered.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration, yoloMode bool, enableWebsearch bool) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
@@ -660,6 +677,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if supportsTools {
toolRegistry = tools.DefaultRegistry()
// Register web search and web fetch tools if enabled via flag
if enableWebsearch {
toolRegistry.RegisterWebSearch()
toolRegistry.RegisterWebFetch()
}
if toolRegistry.Has("bash") {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "This experimental version of Ollama has the \033[1mbash\033[0m tool enabled.")
@@ -667,6 +690,11 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr)
}
if toolRegistry.Has("web_search") || toolRegistry.Has("web_fetch") {
fmt.Fprintln(os.Stderr, "The \033[1mWeb Search\033[0m and \033[1mWeb Fetch\033[0m tools are enabled. Models can search and fetch web content via ollama.com.")
fmt.Fprintln(os.Stderr)
}
if yoloMode {
fmt.Fprintf(os.Stderr, "\033[1mwarning:\033[0m yolo mode - all tool approvals will be skipped\n")
}
@@ -677,6 +705,9 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var messages []api.Message
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -688,13 +719,39 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -707,6 +764,10 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load Load a different model")
fmt.Fprintln(os.Stderr, " /save Save session as a model")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
@@ -716,6 +777,303 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Fprintln(os.Stderr, " Ctrl+O Expand last tool output")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
switch args[1] {
case "history":
scanner.HistoryEnable()
case "nohistory":
scanner.HistoryDisable()
case "wordwrap":
wordWrap = true
fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap":
wordWrap = false
fmt.Println("Set 'nowordwrap' mode.")
case "verbose":
if err := cmd.Flags().Set("verbose", "true"); err != nil {
return err
}
fmt.Println("Set 'verbose' mode.")
case "quiet":
if err := cmd.Flags().Set("verbose", "false"); err != nil {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
thinkValue := api.ThinkValue{Value: true}
var maybeLevel string
if len(args) > 2 {
maybeLevel = args[2]
}
if maybeLevel != "" {
thinkValue.Value = maybeLevel
}
think = &thinkValue
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
if maybeLevel != "" {
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
} else {
fmt.Println("Set 'think' mode.")
}
case "nothink":
think = &api.ThinkValue{Value: false}
// Check if model supports thinking
if client, err := api.ClientFromEnvironment(); err == nil {
if resp, err := client.Show(cmd.Context(), &api.ShowRequest{Model: modelName}); err == nil {
if !slices.Contains(resp.Capabilities, model.CapabilityThinking) {
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", modelName)
}
}
}
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else {
format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2])
}
case "noformat":
format = ""
fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
fmt.Println("Usage: /set parameter <name> <value>")
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /set <parameter|system|history|format|wordwrap|think|verbose> [value]")
}
continue
case strings.HasPrefix(line, "/show"):
args := strings.Fields(line)
if len(args) > 1 {
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.ShowRequest{
Name: modelName,
Options: options,
}
resp, err := client.Show(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get model")
continue
}
switch args[1] {
case "info":
fmt.Fprintf(os.Stderr, " Model\n")
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Name", modelName)
if resp.Details.Family != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Family", resp.Details.Family)
}
if resp.Details.ParameterSize != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Parameter Size", resp.Details.ParameterSize)
}
if resp.Details.QuantizationLevel != "" {
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Quantization", resp.Details.QuantizationLevel)
}
if len(resp.Capabilities) > 0 {
caps := make([]string, len(resp.Capabilities))
for i, c := range resp.Capabilities {
caps[i] = string(c)
}
fmt.Fprintf(os.Stderr, " %-16s %s\n", "Capabilities", strings.Join(caps, ", "))
}
fmt.Fprintln(os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
} else {
fmt.Println(resp.License)
}
case "modelfile":
fmt.Println(resp.Modelfile)
case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
if len(options) > 0 {
fmt.Println("\nUser defined parameters:")
for k, v := range options {
fmt.Printf(" %-30s %v\n", k, v)
}
}
case "system":
switch {
case system != "":
fmt.Println(system + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
fmt.Println(resp.Template)
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
}
} else {
fmt.Println("Usage: /show <info|license|modelfile|parameters|system|template>")
}
continue
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /load <modelname>")
continue
}
newModelName := args[1]
fmt.Printf("Loading model '%s'\n", newModelName)
// Create progress spinner
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// Get client
client, err := api.ClientFromEnvironment()
if err != nil {
p.StopAndClear()
fmt.Println("error: couldn't connect to ollama server")
continue
}
// Check if model exists and get its info
info, err := client.Show(cmd.Context(), &api.ShowRequest{Model: newModelName})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else {
fmt.Printf("error: %v\n", err)
}
continue
}
// For cloud models, no need to preload
if info.RemoteHost == "" {
// Preload the model by sending an empty generate request
req := &api.GenerateRequest{
Model: newModelName,
Think: think,
}
err = client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
return nil
})
if err != nil {
p.StopAndClear()
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", newModelName)
} else if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
} else {
fmt.Printf("error loading model: %v\n", err)
}
continue
}
}
p.StopAndClear()
modelName = newModelName
messages = []api.Message{}
approval.Reset()
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage: /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
continue
}
req := &api.CreateRequest{
Model: args[1],
From: modelName,
Parameters: options,
Messages: messages,
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Printf("error: %v\n", err)
continue
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
@@ -723,14 +1081,16 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 {
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
verbose, _ := cmd.Flags().GetBool("verbose")
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Format: format,
Options: options,
Think: think,
HideThinking: hideThinking,
@@ -738,6 +1098,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
Verbose: verbose,
}
assistant, err := Chat(cmd.Context(), opts)

View File

@@ -234,3 +234,17 @@ ollama create z-image
3. Copy config files (*.json) as config layers
4. Write manifest
```
## FP8 Quantization
Z-Image supports FP8 quantization to reduce memory usage by ~50% while maintaining image quality.
### Usage
```bash
cd ./weights/Z-Image-Turbo
ollama create z-image-fp8 --quantize fp8
```
This quantizes weights during import. The resulting model will be ~15GB instead of ~31GB.

View File

@@ -1,10 +1,8 @@
package api
import (
"encoding/base64"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"
@@ -101,10 +99,10 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imagePath string
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
@@ -118,14 +116,14 @@ func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.Com
return
}
c.SSEvent("done", buildResponse(imagePath, format))
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imagePath string
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imagePath = extractPath(resp.Content)
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
@@ -133,7 +131,7 @@ func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.
return
}
c.JSON(http.StatusOK, buildResponse(imagePath, format))
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
@@ -152,9 +150,9 @@ func parseSize(size string) (int32, int32) {
return int32(w), int32(h)
}
func extractPath(content string) string {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
return strings.TrimSpace(content[idx+16:])
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
@@ -165,23 +163,21 @@ func parseProgress(content string) ImageProgressEvent {
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imagePath, format string) ImageGenerationResponse {
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imagePath == "" {
if imageBase64 == "" {
return resp
}
if format == "url" {
resp.Data[0].URL = "file://" + imagePath
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
data, err := os.ReadFile(imagePath)
if err == nil {
resp.Data[0].B64JSON = base64.StdEncoding.EncodeToString(data)
}
resp.Data[0].B64JSON = imageBase64
}
return resp

197
x/imagegen/cache/teacache.go vendored Normal file
View File

@@ -0,0 +1,197 @@
//go:build mlx
// Package cache provides caching mechanisms for diffusion model inference.
package cache
import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TeaCache implements Timestep Embedding Aware Caching for diffusion models.
// It caches the transformer output and reuses it when timestep values
// are similar between consecutive steps.
//
// For CFG (classifier-free guidance), it caches pos and neg predictions
// separately and always computes CFG fresh to avoid error amplification.
//
// Reference: "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model"
// https://github.com/ali-vilab/TeaCache
type TeaCache struct {
// Cached transformer output from last computed step (non-CFG mode)
cachedOutput *mlx.Array
// Cached CFG outputs (pos and neg separately)
cachedPosOutput *mlx.Array
cachedNegOutput *mlx.Array
// Previous timestep value for difference calculation
prevTimestep float32
// Accumulated difference for rescaling
accumulatedDiff float32
// Configuration
threshold float32 // Threshold for recomputation decision
rescaleFactor float32 // Model-specific rescaling factor
skipEarlySteps int // Number of early steps to never cache
// Statistics
cacheHits int
cacheMisses int
}
// TeaCacheConfig holds configuration for TeaCache.
type TeaCacheConfig struct {
// Threshold for recomputation. Lower = more cache hits, potential quality loss.
// Recommended: 0.05-0.15 for image models
Threshold float32
// Rescale factor to adjust timestep embedding differences.
// Model-specific, typically 1.0-2.0
RescaleFactor float32
// SkipEarlySteps: number of early steps to always compute (never cache).
// Set to 2-3 for CFG mode to preserve structure. 0 = no skipping.
SkipEarlySteps int
}
// DefaultTeaCacheConfig returns default configuration for TeaCache.
func DefaultTeaCacheConfig() *TeaCacheConfig {
return &TeaCacheConfig{
Threshold: 0.1,
RescaleFactor: 1.0,
}
}
// NewTeaCache creates a new TeaCache instance.
func NewTeaCache(cfg *TeaCacheConfig) *TeaCache {
if cfg == nil {
cfg = DefaultTeaCacheConfig()
}
return &TeaCache{
threshold: cfg.Threshold,
rescaleFactor: cfg.RescaleFactor,
skipEarlySteps: cfg.SkipEarlySteps,
}
}
// ShouldCompute determines if we should compute the full forward pass
// or reuse the cached output based on timestep similarity.
//
// Algorithm:
// 1. First step always computes
// 2. Subsequent steps compare |currTimestep - prevTimestep| * rescaleFactor
// 3. If accumulated difference > threshold, compute new output
// 4. Otherwise, reuse cached output
func (tc *TeaCache) ShouldCompute(step int, timestep float32) bool {
// Always compute early steps (critical for structure)
// Check both regular cache and CFG cache
hasCachedOutput := tc.cachedOutput != nil || tc.HasCFGCache()
if step < tc.skipEarlySteps || step == 0 || !hasCachedOutput {
return true
}
// Compute absolute difference between current and previous timestep
diff := timestep - tc.prevTimestep
if diff < 0 {
diff = -diff
}
// Apply rescaling factor
scaledDiff := diff * tc.rescaleFactor
// Accumulate difference (helps track drift over multiple cached steps)
tc.accumulatedDiff += scaledDiff
// Decision based on accumulated difference
if tc.accumulatedDiff > tc.threshold {
tc.accumulatedDiff = 0 // Reset accumulator
return true
}
return false
}
// UpdateCache stores the computed output for potential reuse (non-CFG mode).
func (tc *TeaCache) UpdateCache(output *mlx.Array, timestep float32) {
// Free previous cached output
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
}
// Store new cached values
tc.cachedOutput = output
tc.prevTimestep = timestep
tc.cacheMisses++
}
// UpdateCFGCache stores pos and neg outputs separately for CFG mode.
// This allows CFG to be computed fresh each step, avoiding error amplification.
func (tc *TeaCache) UpdateCFGCache(posOutput, negOutput *mlx.Array, timestep float32) {
// Free previous cached outputs
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
}
// Store new cached values
tc.cachedPosOutput = posOutput
tc.cachedNegOutput = negOutput
tc.prevTimestep = timestep
tc.cacheMisses++
}
// GetCached returns the cached output (non-CFG mode).
func (tc *TeaCache) GetCached() *mlx.Array {
tc.cacheHits++
return tc.cachedOutput
}
// GetCFGCached returns cached pos and neg outputs for CFG mode.
func (tc *TeaCache) GetCFGCached() (pos, neg *mlx.Array) {
tc.cacheHits++
return tc.cachedPosOutput, tc.cachedNegOutput
}
// HasCFGCache returns true if CFG cache is available.
func (tc *TeaCache) HasCFGCache() bool {
return tc.cachedPosOutput != nil && tc.cachedNegOutput != nil
}
// Arrays returns all arrays that should be kept alive.
func (tc *TeaCache) Arrays() []*mlx.Array {
var arrays []*mlx.Array
if tc.cachedOutput != nil {
arrays = append(arrays, tc.cachedOutput)
}
if tc.cachedPosOutput != nil {
arrays = append(arrays, tc.cachedPosOutput)
}
if tc.cachedNegOutput != nil {
arrays = append(arrays, tc.cachedNegOutput)
}
return arrays
}
// Stats returns cache hit/miss statistics.
func (tc *TeaCache) Stats() (hits, misses int) {
return tc.cacheHits, tc.cacheMisses
}
// Free releases all cached arrays.
func (tc *TeaCache) Free() {
if tc.cachedOutput != nil {
tc.cachedOutput.Free()
tc.cachedOutput = nil
}
if tc.cachedPosOutput != nil {
tc.cachedPosOutput.Free()
tc.cachedPosOutput = nil
}
if tc.cachedNegOutput != nil {
tc.cachedNegOutput.Free()
tc.cachedNegOutput = nil
}
}

View File

@@ -44,62 +44,64 @@ func DefaultOptions() ImageGenOptions {
}
}
// Show displays information about an image generation model.
func Show(modelName string, w io.Writer) error {
manifest, err := LoadManifest(modelName)
if err != nil {
return fmt.Errorf("failed to load manifest: %w", err)
}
// Count total size
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalSize += layer.Size
}
}
// Read model_index.json for architecture
var architecture string
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
}
if json.Unmarshal(data, &index) == nil {
architecture = index.Architecture
}
}
// Estimate parameter count from total size (assuming BF16 = 2 bytes per param)
paramCount := totalSize / 2
paramStr := formatParamCount(paramCount)
// Print Model info
fmt.Fprintln(w, " Model")
if architecture != "" {
fmt.Fprintf(w, " %-20s %s\n", "architecture", architecture)
}
fmt.Fprintf(w, " %-20s %s\n", "parameters", paramStr)
fmt.Fprintf(w, " %-20s %s\n", "quantization", "BF16")
fmt.Fprintln(w)
// Print Capabilities
fmt.Fprintln(w, " Capabilities")
fmt.Fprintf(w, " %s\n", "image")
fmt.Fprintln(w)
return nil
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// formatParamCount formats parameter count as human-readable string.
func formatParamCount(count int64) string {
if count >= 1_000_000_000 {
return fmt.Sprintf("%.1fB", float64(count)/1_000_000_000)
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
if count >= 1_000_000 {
return fmt.Sprintf("%.1fM", float64(count)/1_000_000)
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
return fmt.Sprintf("%d", count)
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
@@ -121,11 +123,6 @@ func RegisterFlags(cmd *cobra.Command) {
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Verify it's a valid image gen model
if ResolveModelName(name) == "" {
return fmt.Errorf("unknown image generation model: %s", name)
}
// Get options from flags (with env var defaults)
opts := DefaultOptions()
if cmd != nil && cmd.Flags() != nil {
@@ -183,8 +180,7 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -203,11 +199,9 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
@@ -218,9 +212,27 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
if imagePath != "" {
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
if imageBase64 != "" {
// Decode base64 and save to CWD
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
return fmt.Errorf("failed to decode image: %w", err)
}
// Create filename from prompt
safeName := sanitizeFilename(prompt)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
return fmt.Errorf("failed to save image: %w", err)
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
return nil
@@ -306,7 +318,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
p.Add("", spinner)
var stepBar *progress.StepBar
var imagePath string
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
@@ -326,11 +338,9 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return nil
}
// Handle final response with image path
if resp.Done && strings.Contains(content, "Image saved to:") {
if idx := strings.Index(content, "Image saved to: "); idx >= 0 {
imagePath = strings.TrimSpace(content[idx+16:])
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
}
return nil
@@ -342,25 +352,30 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
continue
}
// Copy image to current directory with descriptive name
if imagePath != "" {
// Save image to current directory with descriptive name
if imageBase64 != "" {
// Decode base64 image data
imageData, err := base64.StdEncoding.DecodeString(imageBase64)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decoding image: %v\n", err)
continue
}
// Create filename from prompt (sanitized)
safeName := sanitizeFilename(line)
if len(safeName) > 50 {
safeName = safeName[:50]
}
timestamp := time.Now().Format("20060102-150405")
newName := fmt.Sprintf("%s-%s.png", safeName, timestamp)
filename := fmt.Sprintf("%s-%s.png", safeName, timestamp)
// Copy file to CWD
if err := copyFile(imagePath, newName); err != nil {
fmt.Fprintf(os.Stderr, "Error saving to current directory: %v\n", err)
displayImageInTerminal(imagePath)
fmt.Printf("Image saved to: %s\n", imagePath)
} else {
displayImageInTerminal(newName)
fmt.Printf("Image saved to: %s\n", newName)
if err := os.WriteFile(filename, imageData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error saving image: %v\n", err)
continue
}
displayImageInTerminal(filename)
fmt.Printf("Image saved to: %s\n", filename)
}
fmt.Println()
@@ -381,24 +396,6 @@ func sanitizeFilename(s string) string {
return result.String()
}
// copyFile copies a file from src to dst.
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer destFile.Close()
_, err = io.Copy(destFile, sourceFile)
return err
}
// printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) {
fmt.Fprintln(os.Stderr, "Commands:")
@@ -509,10 +506,7 @@ func displayImageInTerminal(imagePath string) bool {
// Send in chunks for large images
const chunkSize = 4096
for i := 0; i < len(encoded); i += chunkSize {
end := i + chunkSize
if end > len(encoded) {
end = len(encoded)
}
end := min(i+chunkSize, len(encoded))
chunk := encoded[i:end]
if i == 0 {

View File

@@ -29,9 +29,10 @@ const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir string, p *progress.Progress) error {
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
@@ -58,18 +59,77 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32) (imagegen.LayerInfo, error) {
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return imagegen.LayerInfo{}, err
return nil, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
@@ -119,7 +179,7 @@ func CreateModel(modelName, modelDir string, p *progress.Progress) error {
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, createLayer, createTensorLayer, writeManifest, progressFn)
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err

View File

@@ -0,0 +1,120 @@
//go:build mlx
package client
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// and returns safetensors data for the quantized weights, scales, and biases.
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
}
tmpFile.Close()
// Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
}
defer st.Free()
// Get the tensor (it's stored as "data" in our minimal safetensors format)
arr := st.Get("data")
if arr == nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
}
// Convert to BFloat16 if needed (quantize expects float type)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
mlx.Eval(qweight, scales, qbiases)
} else {
mlx.Eval(qweight, scales)
}
// Get shapes
qweightShape = qweight.Shape()
scalesShape = scales.Shape()
// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
defer os.Remove(qweightPath)
if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
}
qweightData, err = os.ReadFile(qweightPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
}
// Save scales using MLX's native safetensors
scalesPath := filepath.Join(tmpDir, "scales.safetensors")
defer os.Remove(scalesPath)
if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
}
scalesData, err = os.ReadFile(scalesPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
}
// Affine mode returns qbiases for zero-point offset
if qbiases != nil {
qbiasShape = qbiases.Shape()
qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
defer os.Remove(qbiasPath)
if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
}
qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
}
}
return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
func QuantizeSupported() bool {
return true
}
// ensureTempDir creates the temp directory for quantization if it doesn't exist
func ensureTempDir() string {
tmpDir := filepath.Join(os.TempDir(), "ollama-quantize")
os.MkdirAll(tmpDir, 0755)
return tmpDir
}

View File

@@ -0,0 +1,18 @@
//go:build !mlx
package client
import (
"fmt"
"io"
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
func QuantizeSupported() bool {
return false
}

View File

@@ -11,9 +11,11 @@ import (
"os"
"path/filepath"
"runtime/pprof"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -61,12 +63,16 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
flag.Parse()
@@ -99,13 +105,44 @@ func main() {
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
TeaCache: *teaCache,
TeaCacheThreshold: float32(*teaCacheThreshold),
FusedQKV: *fusedQKV,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *glmImageFlag:
m := &glm_image.Model{}
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
var loadErr error
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
loadErr = m.LoadFromPath(*modelPath)
} else {
loadErr = m.Load(*modelPath)
}
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
Temperature: float32(*temperature),
TopP: float32(*topP),
GuidanceScale: float32(*cfgScale),
MaxVisualTokens: int32(*maxTokens),
})
if err == nil {
err = saveImageArray(img, *out)

View File

@@ -40,13 +40,15 @@ type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo)
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTensorLayer TensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
@@ -74,7 +76,11 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
}
tensorNames := extractor.ListTensors()
fn(fmt.Sprintf("importing %s/%s (%d tensors)", component, entry.Name(), len(tensorNames)))
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
@@ -83,16 +89,30 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Count parameters from original tensor shape
if len(td.Shape) > 0 {
numElements := int64(1)
for _, dim := range td.Shape {
numElements *= int64(dim)
}
totalParams += numElements
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
layer, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape)
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
}
layers = append(layers, layer)
layers = append(layers, newLayers...)
}
extractor.Close()
@@ -106,10 +126,13 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"vision_language_encoder/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"processor/tokenizer.json", // GLM-Image main tokenizer
"processor/tokenizer_config.json", // GLM-Image tokenizer config
}
for _, cfgPath := range configFiles {
@@ -122,7 +145,7 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
var r io.Reader
// For model_index.json, normalize to Ollama format
// For model_index.json, normalize to Ollama format and add metadata
if cfgPath == "model_index.json" {
data, err := os.ReadFile(fullPath)
if err != nil {
@@ -141,6 +164,16 @@ func CreateModel(modelName, modelDir string, createLayer LayerCreator, createTen
}
delete(cfg, "_diffusers_version")
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
} else {
cfg["quantization"] = "BF16"
}
data, err = json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)

View File

@@ -60,9 +60,12 @@ func ArrayToImage(arr *mlx.Array) (*image.RGBA, error) {
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
// Free intermediate arrays to avoid memory leak
squeezed := mlx.Squeeze(arr, 0)
transposed := mlx.Transpose(squeezed, 1, 2, 0)
squeezed.Free()
img := mlx.Contiguous(transposed)
transposed.Free()
mlx.Eval(img)
imgShape := img.Shape()

19
x/imagegen/imagegen.md Normal file
View File

@@ -0,0 +1,19 @@
# Image generation models (experimental)
Experimental image generation models are available for **macOS** in Ollama:
## Available models
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
```
ollama run x/z-image-turbo
```
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models
More models coming soon:
1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

View File

@@ -27,6 +27,7 @@ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}
// CheckPlatformSupport validates that image generation is supported on the current platform.

View File

@@ -607,6 +607,11 @@ func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
// Kept returns true if the array is marked to survive Eval() cleanup.
func (a *Array) Kept() bool {
return a != nil && a.kept
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
@@ -1480,6 +1485,44 @@ func (a *Array) ItemInt32() int32 {
return int32(val)
}
// Bytes copies the raw bytes out of the array without type conversion.
// Works with common dtypes (float32, int32, uint32, uint8).
// For non-contiguous arrays, call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Bytes() []byte {
cleanup()
nbytes := a.Nbytes()
if nbytes == 0 {
return nil
}
// Get raw pointer based on dtype
var ptr unsafe.Pointer
switch a.Dtype() {
case DtypeFloat32:
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
case DtypeInt32:
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
case DtypeUint32:
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
case DtypeUint8:
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
default:
// For other types (bf16, f16, etc), convert to float32
arr := AsType(a, DtypeFloat32)
arr.Eval()
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
nbytes = arr.Nbytes()
}
if ptr == nil {
return nil
}
data := make([]byte, nbytes)
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
return data
}
// ============ Utility ============
// String returns a string representation
@@ -1658,6 +1701,34 @@ func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_string_free(s.metadata)
}
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
// This correctly handles all dtypes including uint32 for quantized weights.
func SaveSafetensors(path string, arrays map[string]*Array) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
// Create the map
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
// Add each array to the map
for name, arr := range arrays {
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
C.free(unsafe.Pointer(cName))
}
// Create empty metadata (optional)
cMeta := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMeta)
// Save
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
@@ -1986,7 +2057,8 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
// mode: "affine" (default), "mxfp4", or "mxfp8"
// Note: mxfp8 mode returns nil biases (only weights and scales)
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
@@ -1995,14 +2067,21 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of 3 arrays: [weights, scales, biases]
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
vecSize := int(C.mlx_vector_array_size(res))
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
C.mlx_vector_array_get(&w2, res, 2)
if vecSize >= 3 {
C.mlx_vector_array_get(&w2, res, 2)
}
C.mlx_vector_array_free(res)
return newArray(w0), newArray(w1), newArray(w2)
if vecSize >= 3 {
return newArray(w0), newArray(w1), newArray(w2)
}
return newArray(w0), newArray(w1), nil
}
// Dequantize reconstructs weights from quantized form.

View File

@@ -0,0 +1,693 @@
//go:build mlx
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image
import (
"context"
"fmt"
"math"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
PadTokenID int32
EOSTokenID int32
UNKTokenID int32
}
// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
return &ByT5Tokenizer{
PadTokenID: 0,
EOSTokenID: 1,
UNKTokenID: 2,
}
}
// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
bytes := []byte(text)
tokens := make([]int32, len(bytes))
for i, b := range bytes {
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
// (tokens 0, 1, 2 are PAD, EOS, UNK)
tokens[i] = int32(b) + 3
}
return tokens
}
// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
bytes := make([]byte, 0, len(tokens))
for _, tok := range tokens {
if tok >= 3 && tok < 259 {
bytes = append(bytes, byte(tok-3))
}
}
return string(bytes)
}
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
GuidanceScale float32 // Guidance scale (default: 1.5)
Width int32 // Image width (default: 1024, must be divisible by 32)
Height int32 // Image height (default: 1024, must be divisible by 32)
Steps int // Diffusion denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// AR generation options
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
Temperature float32 // AR sampling temperature (default: 0.9)
TopP float32 // Nucleus sampling (default: 0.75)
}
// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
// Model represents a GLM-Image hybrid model.
type Model struct {
ModelName string
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
TextEncoder *T5TextEncoder
VisionLanguageEncoder *VisionLanguageEncoder
Transformer *DiffusionTransformer
VAEDecoder *VAEDecoder
}
// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
// Used for T5 text encoder (glyph embeddings)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizer(manifest)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder (~830MB)
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder (~19GB, 9B params)
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer (~13GB, 7B params)
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder (~775MB)
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelPath
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizerFromPath(modelPath)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.5
}
// Calculate MaxVisualTokens based on image dimensions
// GLM-Image generates TWO grids of visual tokens:
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
// 2. Then: target (large) grid - tokenH × tokenW tokens
// After generation, we extract only the TARGET grid tokens for diffusion.
factor := int32(32)
tokenH := cfg.Height / factor
tokenW := cfg.Width / factor
targetGridTokens := tokenH * tokenW
// Compute prev grid dimensions using diffusers formula:
// ratio = token_h / token_w
// prev_token_h = int(sqrt(ratio) * 16)
// prev_token_w = int(sqrt(1/ratio) * 16)
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
prevGridTokens := prevTokenH * prevTokenW
// Total tokens to generate = prev grid + target grid
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
if cfg.Temperature <= 0 {
cfg.Temperature = 0.9
}
if cfg.TopP <= 0 {
cfg.TopP = 0.75
}
// Ensure dimensions are divisible by 32
cfg.Width = (cfg.Width / 32) * 32
cfg.Height = (cfg.Height / 32) * 32
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
// Progress callback helper
progress := func(stage string, step, total int) {
if cfg.Progress != nil {
cfg.Progress(stage, step, total)
}
}
// === PHASE 1: T5 Text Encoding ===
fmt.Println("[T5] Encoding glyph text...")
progress("text_encoding", 0, 1)
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
mlx.Keep(textEmbed)
mlx.Eval(textEmbed)
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
progress("text_encoding", 1, 1)
// === PHASE 2: AR Visual Token Generation ===
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
visualTokens := m.VisionLanguageEncoder.Generate(
cfg.Prompt,
m.GLMTokenizer,
cfg.MaxVisualTokens,
cfg.Temperature,
cfg.TopP,
cfg.Seed,
cfg.Height,
cfg.Width,
func(step int) {
if step%100 == 0 || step < 10 {
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
}
progress("ar_generation", step, int(cfg.MaxVisualTokens))
},
)
mlx.Keep(visualTokens)
mlx.Eval(visualTokens)
fmt.Printf("[AR] Done generating visual tokens\n")
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
vtShape := visualTokens.Shape()
totalGenerated := vtShape[1]
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
// Extract only the TARGET grid tokens (skip the prev grid tokens)
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
// large_image_start_offset = prev_grid_size
var targetGridVisualTokens *mlx.Array
if totalGenerated >= prevGridTokens+targetGridTokens {
// Full generation completed - extract target grid
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, prevGridTokens + targetGridTokens})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
} else if totalGenerated > prevGridTokens {
// Partial target grid - take what we have
actualTargetTokens := totalGenerated - prevGridTokens
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, totalGenerated})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
actualTargetTokens, targetGridTokens)
} else {
// Not enough tokens - EOS came too early
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
totalGenerated, prevGridTokens)
}
// === PHASE 3: Diffusion Decoding ===
// Setup scheduler with dynamic shift based on image size
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
// Initialize noise latents [B, C, H, W]
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
// target_grid tokens -> 2x upsample -> patch_count
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
mlx.Keep(priorEmbed)
mlx.Eval(priorEmbed)
// Prepare text conditioning (project T5 embeddings)
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
mlx.Keep(textCond)
mlx.Eval(textCond)
// === CFG Setup ===
// For classifier-free guidance, we need unconditional (negative) text embeddings
// GLM-Image uses empty string "" for negative prompt
doCFG := cfg.GuidanceScale > 1.0
var negativeTextCond *mlx.Array
if doCFG {
// Encode empty string for negative prompt
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
mlx.Keep(negativeTextEmbed)
mlx.Eval(negativeTextEmbed)
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
mlx.Keep(negativeTextCond)
mlx.Eval(negativeTextCond)
negativeTextEmbed.Free()
}
// Prepare conditioning inputs
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
targetSize = mlx.ToBFloat16(targetSize)
cropCoords = mlx.ToBFloat16(cropCoords)
mlx.Keep(targetSize)
mlx.Keep(cropCoords)
mlx.Eval(targetSize, cropCoords)
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
// Denoising loop
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
progress("diffusion", 0, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
latents.Free()
return nil, ctx.Err()
default:
}
}
// Get timestep value for the transformer
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
timestepVal := scheduler.Timesteps[i] - 1
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
patches := PatchifyLatents(latents, tcfg.PatchSize)
// Transformer forward with MMDiT architecture
// Conditional pass (with text + prior embeddings)
outputCond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed,
textCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
false, // priorTokenDrop = false for conditional
)
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
var noisePred *mlx.Array
if doCFG {
// Unconditional pass (empty text, dropped prior embeddings)
outputUncond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
negativeTextCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
true, // priorTokenDrop = true for unconditional
)
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
diff := mlx.Sub(noisePredCond, noisePredUncond)
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
noisePred = mlx.Add(noisePredUncond, scaled)
} else {
noisePred = noisePredCond
}
// Scheduler step
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
progress("diffusion", i+1, cfg.Steps)
}
// Cleanup intermediate arrays
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
if negativeTextCond != nil {
negativeTextCond.Free()
}
targetSize.Free()
cropCoords.Free()
// === PHASE 4: VAE Decode ===
progress("vae_decode", 0, 1)
decoded := m.VAEDecoder.Decode(latents)
mlx.Eval(decoded)
latents.Free()
progress("vae_decode", 1, 1)
return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
// Each token at (i, j) becomes scale*scale tokens in the output
mlx.Eval(tokens)
tokenData := tokens.DataInt32()
numTokens := int32(len(tokenData))
expectedTokens := prevH * prevW
// Warn if we got fewer tokens than expected (early EOS)
if numTokens < expectedTokens {
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
numTokens, expectedTokens)
}
targetH := prevH * scale
targetW := prevW * scale
upsampled := make([]int32, targetH*targetW)
for i := int32(0); i < prevH; i++ {
for j := int32(0); j < prevW; j++ {
srcIdx := i*prevW + j
// Handle early EOS: use 0 (padding) for missing tokens
var val int32
if srcIdx < numTokens {
val = tokenData[srcIdx]
} else {
val = 0 // Padding token
}
// Place in scale*scale positions
dstI := i * scale
dstJ := j * scale
for di := int32(0); di < scale; di++ {
for dj := int32(0); dj < scale; dj++ {
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
}
}
}
}
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// Transpose: -> [B, pH, pW, C, p, p]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// Flatten: -> [B, pH*pW, C*p*p]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// Transpose: -> [B, C, pH, p, pW, p]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// Reshape: -> [B, C, H, W]
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
cfg := DefaultSchedulerConfig()
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
return m*cfg.MaxShift + cfg.BaseShift
}
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
shape := tokens.Shape()
B := shape[0]
// Reshape to [B, 1, H, W] for interpolation
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
// Convert to float for interpolation
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
// 2x nearest neighbor upsample
// [B, 1, H, W] -> [B, 1, H*2, W*2]
upsampled := nearestUpsample2x(tokensFloat)
// Convert back to int and reshape to [B, H*2*W*2]
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Repeat each element 2x2
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Tile to repeat each pixel 2x2
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
// Reshape to final size
return mlx.Reshape(x, B, C, H*2, W*2)
}

View File

@@ -0,0 +1,358 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen"
)
// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
Vocab map[string]int32 // token string -> token ID
VocabReverse map[int32]string // token ID -> token string
SpecialTokens map[string]int32 // special token strings -> IDs
// Special token IDs
SopTokenID int32 // <sop> = grid_bos_token (167845)
EopTokenID int32 // <eop> = grid_eos_token (167846)
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
PadTokenID int32
// Sorted vocab keys by length (longest first) for greedy matching
sortedTokens []string
}
// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
Model struct {
Vocab map[string]int32 `json:"vocab"`
} `json:"model"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory in manifest
data, err := manifest.ReadConfig("processor/tokenizer.json")
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
data, err := os.ReadFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
if text == "" {
return []int32{}
}
var tokens []int32
// First, check for and handle special tokens
// Replace special tokens with placeholders, encode, then restore
specialReplacements := make(map[string]int32)
for special, id := range t.SpecialTokens {
if strings.Contains(text, special) {
specialReplacements[special] = id
}
}
// Process text character by character with special token handling
i := 0
isFirstToken := true
for i < len(text) {
// Check for special tokens first
foundSpecial := false
for special, id := range specialReplacements {
if strings.HasPrefix(text[i:], special) {
tokens = append(tokens, id)
i += len(special)
isFirstToken = false
foundSpecial = true
break
}
}
if foundSpecial {
continue
}
// Handle regular text with GPT-2 style space prefix
// "Ġ" (U+0120) represents a space before a token
remaining := text[i:]
// Try to find the longest matching token
matched := false
for _, token := range t.sortedTokens {
// Skip special tokens in regular matching
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
continue
}
// Check if this token matches
tokenText := token
// Handle the Ġ prefix (represents space)
if strings.HasPrefix(token, "Ġ") {
// This token expects a leading space
if i > 0 || !isFirstToken {
// Check if remaining starts with space + token content
tokenContent := token[len("Ġ"):]
if strings.HasPrefix(remaining, " "+tokenContent) {
tokens = append(tokens, t.Vocab[token])
i += 1 + len(tokenContent) // space + content
isFirstToken = false
matched = true
break
}
}
} else {
// Regular token without space prefix
if strings.HasPrefix(remaining, tokenText) {
tokens = append(tokens, t.Vocab[token])
i += len(tokenText)
isFirstToken = false
matched = true
break
}
}
}
if !matched {
// No token found - skip this character (or use UNK)
// For now, just skip unknown characters
i++
}
}
return tokens
}
// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
// Calculate grid dimensions
factor := int32(32)
height := (targetHeight / factor) * factor
width := (targetWidth / factor) * factor
tokenH := height / factor
tokenW := width / factor
// Calculate previous grid dimensions
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(sqrt(ratio) * 16)
prevTokenW := int32(sqrt(1.0/ratio) * 16)
// Encode the prompt text
promptTokens := t.Encode(prompt)
// Build the full sequence:
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
var tokens []int32
tokens = append(tokens, promptTokens...)
// First grid: <sop> H W <eop>
// First number has no space prefix, second number has space prefix (Ġ)
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(tokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
tokens = append(tokens, t.EopTokenID)
// Second grid: <sop> prevH prevW <eop>
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
tokens = append(tokens, t.EopTokenID)
// BOS token (start of image generation)
tokens = append(tokens, t.BosTokenID)
return tokens
}
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up the whole number as a single token
if id, ok := t.Vocab[s]; ok {
return []int32{id}
}
// Fallback: encode digit by digit
var tokens []int32
for _, c := range s {
if id, ok := t.Vocab[string(c)]; ok {
tokens = append(tokens, id)
}
}
return tokens
}
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
spaceToken := "Ġ" + s
if id, ok := t.Vocab[spaceToken]; ok {
return []int32{id}
}
// Fallback: bare space Ġ + number tokens
var tokens []int32
if spaceID, ok := t.Vocab["Ġ"]; ok {
tokens = append(tokens, spaceID)
}
tokens = append(tokens, t.encodeNumber(n)...)
return tokens
}
// sqrt is a helper for float64 sqrt
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
// Newton's method
z := x
for i := 0; i < 10; i++ {
z = z - (z*z-x)/(2*z)
}
return z
}
// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
var sb strings.Builder
for _, id := range tokens {
if token, ok := t.VocabReverse[id]; ok {
// Handle Ġ prefix (convert back to space)
if strings.HasPrefix(token, "Ġ") {
sb.WriteString(" ")
sb.WriteString(token[len("Ġ"):])
} else {
sb.WriteString(token)
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,159 @@
//go:build mlx
package glm_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.25
MaxShift float32 `json:"max_shift"` // 0.75
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "linear"
}
// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.25,
MaxShift: 0.75,
BaseImageSeqLen: 256,
MaxImageSeqLen: 4096,
UseDynamicShifting: true,
TimeShiftType: "linear",
}
}
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
Sigmas []float32 // Shifted sigmas for Euler step calculation
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{Config: cfg}
}
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
s.NumSteps = numSteps
// Calculate shift (mu) based on image sequence length
mu := s.calculateShift(imgSeqLen)
// Create timesteps: linspace from sigma_max_t to sigma_min_t
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
// Then apply time shift and append terminal sigma=0
s.Timesteps = make([]float32, numSteps)
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
for i := 0; i < numSteps; i++ {
// linspace from 1000 to ~20 (sigma_min * num_train_timesteps)
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
s.Timesteps[i] = tRaw
// Convert to sigma [0, 1]
sigma := tRaw / numTrainTimesteps
// Apply time shift if enabled
if s.Config.UseDynamicShifting && mu > 0 {
sigma = s.applyShift(mu, sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal sigma = 0 (the final clean image)
s.Sigmas[numSteps] = 0
}
// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
// m = (image_seq_len / base_seq_len) ** 0.5
// mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
cfg := s.Config
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
mu := m*cfg.MaxShift + cfg.BaseShift
return mu
}
// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if t >= 1 {
return 1
}
// sigma=1.0 for both shift types
sigma := float32(1.0)
if s.Config.TimeShiftType == "linear" {
// Linear: mu / (mu + (1/t - 1)^sigma)
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
sigma := s.Sigmas[stepIdx]
sigmaNext := s.Sigmas[stepIdx+1]
// Euler step: x_{t-dt} = x_t + dt * v_t
dt := sigmaNext - sigma // Negative (going from noise to clean)
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
// Use sigmas (shifted) for the interpolation
sigma := s.Sigmas[timestepIdx]
oneMinusSigma := 1.0 - sigma
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
scaledNoise := mlx.MulScalar(noise, sigma)
return mlx.Add(scaledClean, scaledNoise)
}
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
return s.Timesteps
}

View File

@@ -0,0 +1,497 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// T5Config holds T5 encoder configuration
type T5Config struct {
DModel int32 `json:"d_model"` // 1472
DFF int32 `json:"d_ff"` // 3584
DKV int32 `json:"d_kv"` // 64
NumHeads int32 `json:"num_heads"` // 6
NumLayers int32 `json:"num_layers"` // 12
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
// Relative position bias
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}
// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
Config *T5Config
// Embedding (shared for ByT5)
SharedEmbed *nn.Embedding `weight:"shared"`
// Encoder layers
Layers []*T5Block `weight:"encoder.block"`
// Final layer norm
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
// Relative position bias (from first layer, shared across all)
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}
// T5Block is a single T5 encoder block
type T5Block struct {
// Self attention
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
// FFN
Layer1 *T5LayerFF `weight:"layer.1"`
}
// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
SelfAttention *T5Attention `weight:"SelfAttention"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5Attention implements T5's relative attention
type T5Attention struct {
Q *mlx.Array `weight:"q.weight"` // No bias in T5
K *mlx.Array `weight:"k.weight"`
V *mlx.Array `weight:"v.weight"`
O *mlx.Array `weight:"o.weight"`
NHeads int32
DKV int32
Scale float32
}
// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
Wo *mlx.Array `weight:"wo.weight"` // down projection
}
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32
}
// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
func (m *T5TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.LayerNormEps
for _, block := range m.Layers {
attn := block.Layer0.SelfAttention
attn.NHeads = cfg.NumHeads
attn.DKV = cfg.DKV
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
}
}
// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
cfg := m.Config
// Get embeddings
h := m.SharedEmbed.Forward(tokens)
// Compute relative position bias once
seqLen := tokens.Shape()[1]
posBias := m.computeRelativePositionBias(seqLen)
// Forward through layers
for _, block := range m.Layers {
h = block.Forward(h, posBias, cfg.LayerNormEps)
}
// Final norm
h = m.FinalNorm.Forward(h)
return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
func extractGlyphTexts(prompt string) []string {
var glyphTexts []string
// Extract text in single quotes: 'text'
re1 := regexp.MustCompile(`'([^']*)'`)
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Unicode curly double quotes: "text"
re2 := regexp.MustCompile(`"([^""]*)"`)
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in ASCII double quotes: "text"
re3 := regexp.MustCompile(`"([^"]*)"`)
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Japanese quotes: 「text」
re4 := regexp.MustCompile(`「([^「」]*)」`)
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
return glyphTexts
}
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
// Extract glyph texts from prompt (text in quotes)
glyphTexts := extractGlyphTexts(prompt)
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
if len(glyphTexts) == 0 {
glyphTexts = []string{""}
}
// Encode each glyph text and collect token sequences
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
var allTokenSeqs [][]int32
for _, glyphText := range glyphTexts {
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
tokens := tok.Encode(glyphText)
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
tokens = append(tokens, tok.EOSTokenID)
allTokenSeqs = append(allTokenSeqs, tokens)
}
// Process each glyph text through the encoder
var allEmbeddings []*mlx.Array
for _, tokens := range allTokenSeqs {
tokenLen := len(tokens)
if tokenLen == 0 {
continue
}
// Create token array [1, L]
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
// Forward through encoder
output := m.Forward(tokensArr)
mlx.Eval(output)
allEmbeddings = append(allEmbeddings, output)
}
// Concatenate all glyph embeddings along sequence dimension
var output *mlx.Array
if len(allEmbeddings) == 0 {
// Fallback: return single zero embedding
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
} else if len(allEmbeddings) == 1 {
output = allEmbeddings[0]
} else {
output = mlx.Concatenate(allEmbeddings, 1)
}
mlx.Eval(output)
return output
}
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
cfg := m.Config
// Create relative position matrix
// For each (query_pos, key_pos) pair, compute bucketed relative position
numBuckets := cfg.RelativeAttentionNumBuckets
maxDistance := cfg.RelativeAttentionMaxDistance
// Create position indices
contextPos := make([]int32, seqLen*seqLen)
memoryPos := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen; i++ {
for j := int32(0); j < seqLen; j++ {
contextPos[i*seqLen+j] = i
memoryPos[i*seqLen+j] = j
}
}
// Compute relative positions and bucket them
buckets := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen*seqLen; i++ {
relPos := memoryPos[i] - contextPos[i]
buckets[i] = relativePosistionBucket(relPos, numBuckets, maxDistance, false)
}
// Create bucket indices array
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
// Transpose to [numHeads, seqLen, seqLen]
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
return bias
}
// relativePosistionBucket computes the bucket for a relative position
func relativePosistionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
var bucket int32 = 0
var n int32 = -relativePosition
if bidirectional {
numBuckets /= 2
if n < 0 {
bucket += numBuckets
n = -n
}
} else {
if n < 0 {
n = 0
}
}
// Half buckets are for exact positions, half are for log-spaced
maxExact := numBuckets / 2
if n < maxExact {
bucket += n
} else {
// Log-spaced buckets
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
if bucket > numBuckets-1 {
bucket = numBuckets - 1
}
}
return bucket
}
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Self attention with residual
h := b.Layer0.Forward(x, posBias, eps)
// FFN with residual
h = b.Layer1.Forward(h, eps)
return h
}
// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// Attention
attnOut := l.SelfAttention.Forward(normed, posBias)
// Residual
return mlx.Add(x, attnOut)
}
// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Q, K, V projections (no bias)
// Weights are [out_features, in_features], so we use matmul with transpose
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
// Reshape to [B, L, nheads, d_kv]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
// Transpose to [B, nheads, L, d_kv]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Attention scores with relative position bias
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
// (no 1/sqrt(d_k) scale factor like in standard transformers)
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
scores = mlx.Add(scores, posBias)
// Softmax
attnWeights := mlx.Softmax(scores, -1)
// Attend to values
out := mlx.Matmul(attnWeights, v)
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, D]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
_ = D // Silence unused warning
return out
}
// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// FFN
ffOut := l.DenseReluDense.Forward(normed)
// Residual
return mlx.Add(x, ffOut)
}
// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
coeff := float32(0.044715)
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
gate = geluNew(gate)
// Up projection
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
// Gated output
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}
// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
variance := mlx.Mean(mlx.Square(x), -1, true)
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
return mlx.Mul(x, ln.Weight)
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,477 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"` // 3
OutChannels int32 `json:"out_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 16
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
LayersPerBlock int32 `json:"layers_per_block"` // 3
NormNumGroups int32 `json:"norm_num_groups"` // 32
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
ShiftFactor *float32 `json:"shift_factor"` // null
LatentsMean []float32 `json:"latents_mean"` // [16 values]
LatentsStd []float32 `json:"latents_std"` // [16 values]
}
// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
Config *VAEConfig
// Decoder components
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
}
// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Stride int32
Padding int32
}
// GroupNorm is group normalization
type GroupNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
}
// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
}
// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
Norm1 *GroupNorm `weight:"norm1"`
Conv1 *VAEConv2d `weight:"conv1"`
Norm2 *GroupNorm `weight:"norm2"`
Conv2 *VAEConv2d `weight:"conv2"`
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}
// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
Conv *VAEConv2d `weight:"conv"`
}
// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
// And all but the last up_block has an upsampler
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
// All but the last block has upsamplers
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
func (m *VAEDecoder) initGroupNorms() {
cfg := m.Config
numGroups := cfg.NormNumGroups
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
if m.ConvNormOut != nil {
m.ConvNormOut.NumGroups = numGroups
m.ConvNormOut.Eps = eps
}
if m.MidBlock != nil {
for _, resnet := range m.MidBlock.Resnets {
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
for _, upBlock := range m.UpBlocks {
if upBlock == nil {
continue
}
for _, resnet := range upBlock.Resnets {
if resnet == nil {
continue
}
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Apply latent denormalization if mean/std are provided
// This matches diffusers GLM-Image: latents = latents * std + mean
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
latents = m.denormalizeLatents(latents)
}
// Convert from NCHW to NHWC for processing
// [B, C, H, W] -> [B, H, W, C]
x := mlx.Transpose(latents, 0, 2, 3, 1)
// Initial convolution
x = m.ConvIn.Forward(x)
// Mid block
x = m.MidBlock.Forward(x)
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
for i := 0; i < len(m.UpBlocks); i++ {
if m.UpBlocks[i] != nil {
x = m.UpBlocks[i].Forward(x)
}
}
// Final normalization and convolution
x = m.ConvNormOut.Forward(x)
x = mlx.SiLU(x)
x = m.ConvOut.Forward(x)
// Convert back to NCHW
// [B, H, W, C] -> [B, C, H, W]
x = mlx.Transpose(x, 0, 3, 1, 2)
// Clamp to valid range and convert to [0, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.AddScalar(x, 1.0)
x = mlx.DivScalar(x, 2.0)
return x
}
// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Create mean and std arrays [1, C, 1, 1] for broadcasting
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
// Denormalize: latents * std + mean
latents = mlx.Mul(latents, std)
latents = mlx.Add(latents, mean)
return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C_in] (NHWC)
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
// So we need to transpose from OIHW to OHWI
stride := c.Stride
if stride == 0 {
stride = 1
}
padding := c.Padding
if padding == 0 {
// Default to same padding for 3x3 kernels
wShape := c.Weight.Shape()
if len(wShape) >= 3 && wShape[2] == 3 {
padding = 1
}
}
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
out := mlx.Conv2d(x, weight, stride, padding)
if c.Bias != nil {
// Bias: [C_out] -> [1, 1, 1, C_out]
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
out = mlx.Add(out, bias)
}
return out
}
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
numGroups := gn.NumGroups
if numGroups == 0 {
numGroups = 32
}
groupSize := C / numGroups
// Reshape to [B, H, W, groups, groupSize]
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range mb.Resnets {
x = resnet.Forward(x)
}
return x
}
// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
// Apply resnets
for _, resnet := range ub.Resnets {
if resnet != nil {
x = resnet.Forward(x)
}
}
// Apply upsamplers
for _, upsampler := range ub.Upsamplers {
if upsampler != nil {
x = upsampler.Forward(x)
}
}
return x
}
// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
// First norm + activation + conv
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
// Second norm + activation + conv
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
// Shortcut for channel mismatch
if rb.ConvShortcut != nil {
residual = rb.ConvShortcut.Forward(residual)
}
return mlx.Add(h, residual)
}
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C]
// 2x nearest neighbor upsample
x = upsample2x(x)
// Conv
if us.Conv != nil {
x = us.Conv.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
hIndices := make([]int32, H*2)
for i := int32(0); i < H; i++ {
hIndices[i*2] = i
hIndices[i*2+1] = i
}
wIndices := make([]int32, W*2)
for i := int32(0); i < W; i++ {
wIndices[i*2] = i
wIndices[i*2+1] = i
}
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
// Take along height axis
x = mlx.Reshape(x, B*H, W, C)
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
x = mlx.Reshape(x, B, H, W*2, C)
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
x = mlx.Reshape(x, B*(W*2), H, C)
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
x = mlx.Reshape(x, B, W*2, H*2, C)
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
return x
}

View File

@@ -0,0 +1,982 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 4096
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
IntermediateSize int32 `json:"intermediate_size"` // 13696
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
VocabSize int32 `json:"vocab_size"` // 168064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
// RoPE config
RopeTheta float32 `json:"rope_theta"` // 10000
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
// Visual token config
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
ImageTokenID int32 `json:"image_token_id"` // 167855
// Computed
HeadDim int32
}
// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
Config *VisionLanguageConfig
// Embedding
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
// Transformer layers
Layers []*GLMBlock `weight:"model.language_model.layers"`
// Final norm
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
// LM Head
LMHead *mlx.Array `weight:"lm_head.weight"`
}
// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
// Pre-attention norm (GLM uses post-LN variant)
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
// Attention
SelfAttn *GLMAttention `weight:"self_attn"`
// MLP (fused gate_up)
MLP *GLMMLP `weight:"mlp"`
}
// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
QProj *mlx.Array `weight:"q_proj.weight"`
KProj *mlx.Array `weight:"k_proj.weight"`
VProj *mlx.Array `weight:"v_proj.weight"`
OProj *mlx.Array `weight:"o_proj.weight"`
// QKV have biases in GLM
QBias *mlx.Array `weight:"q_proj.bias"`
KBias *mlx.Array `weight:"k_proj.bias"`
VBias *mlx.Array `weight:"v_proj.bias"`
// Computed
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
PartialRotary float32 // Only rotate this fraction of head_dim
RopeTheta float32
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
Layers []cache.Cache
}
// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
layers := make([]cache.Cache, numLayers)
for i := range layers {
layers[i] = cache.NewKVCache()
}
return &ARCache{Layers: layers}
}
// Free releases all cached tensors
func (c *ARCache) Free() {
for _, layer := range c.Layers {
for _, arr := range layer.State() {
if arr != nil {
arr.Free()
}
}
}
}
// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
DownProj *mlx.Array `weight:"down_proj.weight"`
}
// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
return fmt.Errorf("config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &rawCfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
func (m *VisionLanguageEncoder) initComputedFields() {
cfg := m.Config
for _, block := range m.Layers {
block.SelfAttn.NHeads = cfg.NumAttentionHeads
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
block.SelfAttn.HeadDim = cfg.HeadDim
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
block.SelfAttn.RopeTheta = cfg.RopeTheta
block.SelfAttn.MRoPESection = cfg.MRoPESection
// Set norm eps
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
}
m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
prompt string,
tok *GLMTokenizer,
maxTokens int32,
temperature float32,
topP float32,
seed int64,
targetHeight, targetWidth int32,
progressFn func(int),
) *mlx.Array {
cfg := m.Config
// Encode prompt with grid tokens using GLM tokenizer
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
// Calculate grid dimensions for MRoPE position IDs
factor := int32(32)
tokenH := targetHeight / factor
tokenW := targetWidth / factor
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
prevGridSize := prevTokenH * prevTokenW
// Create KV cache for all layers
cache := NewARCache(cfg.NumHiddenLayers)
defer cache.Free()
// ===== PREFILL PHASE =====
// Process entire prompt at once, populate cache
promptLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
mlx.Eval(h)
// Compute position IDs for prefill (text tokens use same position for all dims)
prefillPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
prefillPositions[dim] = make([]int32, promptLen)
for i := int32(0); i < promptLen; i++ {
prefillPositions[dim][i] = i
}
}
// Forward through layers (prefill)
for i, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
if i > 0 {
oldH.Free()
}
}
// Eval h and cache arrays together so cache is materialized
evalArgs := []*mlx.Array{h}
for _, lc := range cache.Layers {
evalArgs = append(evalArgs, lc.State()...)
}
mlx.Eval(evalArgs...)
// Final norm and get logits for last position
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
h.Free()
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
lastH.Free()
// Sample first token
var sampleCounter int64 = 0
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
// AR generation loop with caching
// Visual tokens are stored as VQ codebook indices [0, 16383]
// The LM head outputs indices [0, 16511] where:
// - [0, 16383] are VQ codes
// - 16384 is BOS
// - 16385 is EOS
visualTokens := make([]int32, 0, maxTokens)
posOffset := promptLen
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
// Preallocate slice for old cache state to reuse
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for i := int32(0); i < maxTokens; i++ {
if progressFn != nil {
progressFn(int(i))
}
// Check for end token (EOS = 16385)
if nextToken == cfg.ImageEndTokenID {
break
}
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
if nextToken == cfg.ImageStartTokenID {
// BOS token - skip storing but continue generation
} else if nextToken < cfg.ImageStartTokenID {
// This is an actual VQ code [0, 16383] - store it
visualTokens = append(visualTokens, nextToken)
}
// Tokens >= 16386 are other special tokens, skip them
// ===== DECODE PHASE =====
// Save old cache state before forward (to free after eval)
oldCacheState = oldCacheState[:0]
for _, lc := range cache.Layers {
oldCacheState = append(oldCacheState, lc.State()...)
}
// Only process the new token, use cached K,V
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
// Compute MRoPE position IDs for this visual token
// Visual tokens are arranged in two grids: prev grid then target grid
// Position dimensions: [temporal, height, width]
decodePositions := computeVisualTokenPositions(
visualTokenIdx, posOffset, promptLen,
prevTokenH, prevTokenW, prevGridSize,
tokenH, tokenW,
)
// Forward through layers (decode with cache)
for j, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
if j > 0 { // Don't free the embedding on first layer
oldH.Free()
}
}
// Eval h and new cache state
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for _, lc := range cache.Layers {
newCacheState = append(newCacheState, lc.State()...)
}
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
// Free old cache state (now that new state is evaluated)
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
// Final norm
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
// Get logits (h is already [1, 1, hidden_size])
h = mlx.Reshape(h, 1, cfg.HiddenSize)
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
h.Free()
// Sample next token
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
posOffset++
visualTokenIdx++
// Periodically clear cache to release intermediate memory
if i%256 == 0 {
mlx.ClearCache()
}
}
if len(visualTokens) == 0 {
// Return at least one token to avoid empty tensor issues
visualTokens = append(visualTokens, 0)
}
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
visualIdx int32, absPos int32, promptLen int32,
prevH, prevW, prevSize int32,
targetH, targetW int32,
) [][]int32 {
positions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
positions[dim] = make([]int32, 1)
}
// First grid (prev grid) starts at decode_pos = promptLen
prevGridDecodePos := promptLen
// Second grid (target grid) starts after first grid
// next_pos = prev_decode_pos + max(prevH, prevW)
maxPrev := prevH
if prevW > maxPrev {
maxPrev = prevW
}
targetGridDecodePos := prevGridDecodePos + maxPrev
// Compute position IDs based on which grid the token is in
if visualIdx < prevSize {
// Token is in the prev grid (prev_token_h × prev_token_w)
row := visualIdx / prevW
col := visualIdx % prevW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = prevGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = prevGridDecodePos + row
positions[2][0] = prevGridDecodePos + col
} else {
// Token is in the target grid (token_h × token_w)
targetIdx := visualIdx - prevSize
row := targetIdx / targetW
col := targetIdx % targetW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = targetGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = targetGridDecodePos + row
positions[2][0] = targetGridDecodePos + col
}
_ = targetH // Used for documentation clarity
_ = absPos // No longer used - kept for API compatibility
return positions
}
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
// Output index directly corresponds to vocab ID [0, 16511]
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
visualLogits := logits
// Apply temperature
if temperature != 1.0 && temperature > 0 {
visualLogits = mlx.DivScalar(visualLogits, temperature)
}
// Apply softmax to get probabilities
probs := mlx.Softmax(visualLogits, -1)
mlx.Eval(probs)
// Get the sampled index using top-p sampling
// This directly gives us the vocab ID in [0, 16511]
// Special tokens: 16384 = BOS, 16385 = EOS
// Use seed + counter for reproducible but different random values
effectiveSeed := seed + *sampleCounter
*sampleCounter++
return sampleTopP(probs, topP, effectiveSeed)
}
// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
// Negate probs for descending sort (Argsort only does ascending)
negProbs := mlx.MulScalar(probs, -1)
sortedIndices := mlx.Argsort(negProbs, -1)
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
cumProbs := mlx.Cumsum(sortedProbs, -1)
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
// Find cutoff index where cumulative probability exceeds topP
probsData := sortedProbs.Data()
cumProbsData := cumProbs.Data()
indicesData := sortedIndices.DataInt32()
// Calculate cutoff and renormalize
var cutoffIdx int
var totalProb float32
for i, cp := range cumProbsData {
totalProb += probsData[i]
if cp >= topP {
cutoffIdx = i + 1 // Include this token
break
}
}
if cutoffIdx == 0 {
cutoffIdx = len(probsData) // Use all tokens if topP is very high
}
// Sample from the truncated distribution
// Renormalize the truncated probabilities
truncatedProbs := make([]float32, cutoffIdx)
for i := 0; i < cutoffIdx; i++ {
truncatedProbs[i] = probsData[i] / totalProb
}
// Sample using random number with provided seed for reproducibility
r := mlx.RandomUniform([]int32{1}, uint64(seed))
mlx.Eval(r)
randVal := r.Data()[0]
// Find the sampled token
var cumulative float32
for i := 0; i < cutoffIdx; i++ {
cumulative += truncatedProbs[i]
if randVal < cumulative {
return indicesData[i]
}
}
// Fallback to the last token in truncated set
return indicesData[cutoffIdx-1]
}
// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}
// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
// Pre-attention norm
normed := b.InputLayerNorm.Forward(x, eps)
// Self-attention with RoPE/MRoPE and cache
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
// Post-attention norm (GLM-4 style)
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
// Residual connection
x = mlx.Add(x, attnOut)
// Post-attention layer norm
normed = b.PostAttnLayerNorm.Forward(x, eps)
// MLP
mlpOut := b.MLP.Forward(normed)
// Post-MLP norm
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
// Residual connection
x = mlx.Add(x, mlpOut)
return x
}
// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}
// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
// Q, K, V projections
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
// Add biases
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// Apply partial RoPE or MRoPE
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
if len(attn.MRoPESection) == 3 && positionIDs != nil {
// Use MRoPE with explicit position IDs
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else if len(attn.MRoPESection) == 3 {
// Use MRoPE with sequential positions (same for all dims - text mode)
seqPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
seqPositions[dim] = make([]int32, L)
for i := int32(0); i < L; i++ {
seqPositions[dim][i] = i + posOffset
}
}
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else {
// Fallback to standard RoPE
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Update cache and get full K, V for attention
if kvcache != nil {
k, v = kvcache.Update(k, v, int(L))
}
// Repeat KV for GQA
kExpanded := k
vExpanded := v
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
kExpanded = repeatKV(k, repeats)
vExpanded = repeatKV(v, repeats)
}
// Scaled dot-product attention with causal mask
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, hidden_size]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
return out
}
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}
// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply RoPE to rotary part with position offset
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply MRoPE to rotary part
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Validate mrope_section sums to half (number of frequency pairs)
var totalPairs int32
for _, s := range mropeSection {
totalPairs += s
}
if totalPairs != half {
// Fallback to standard RoPE if section doesn't match
return applyRoPEWithOffset(x, L, 0, theta)
}
// Build angles for each position dimension (matching Python's MRoPE approach)
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
angleVals := make([]*mlx.Array, 3)
freqOffset := int32(0)
for dim := 0; dim < 3; dim++ {
numPairs := mropeSection[dim]
if numPairs == 0 {
continue
}
// Compute inverse frequencies for this section
// Each dimension uses DIFFERENT frequency ranges:
// - Temporal: frequencies 0 to section[0]-1
// - Height: frequencies section[0] to section[0]+section[1]-1
// - Width: frequencies section[0]+section[1] to sum(section)-1
freqsArr := make([]float32, numPairs)
for i := int32(0); i < numPairs; i++ {
globalIdx := freqOffset + i
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
// Position indices for this dimension
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(positionIDs[dim][i])
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, numPairs] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
freqOffset += numPairs
}
// Concatenate all sections: [L, half] = [L, 32]
allAngles := mlx.Concatenate(angleVals, 1)
// Duplicate AFTER concatenation: [L, D] = [L, 64]
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
// Compute cos/sin
allCos := mlx.Cos(allAngles)
allSin := mlx.Sin(allAngles)
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
allCos = mlx.Reshape(allCos, 1, L, 1, D)
allSin = mlx.Reshape(allSin, 1, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
return applyRoPEWithOffset(x, seqLen, 0, theta)
}
// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Compute inverse frequencies: 1 / (theta^(2i/d))
freqsArr := make([]float32, half)
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
// Position indices with offset
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(i + posOffset)
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, half] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
angles := mlx.Mul(posExpanded, freqsExpanded)
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
cosVals := mlx.Cos(anglesDup)
sinVals := mlx.Sin(anglesDup)
cosVals = mlx.Reshape(cosVals, L, 1, D)
sinVals = mlx.Reshape(sinVals, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
// x: [B, nkvheads, L, head_dim]
x = mlx.ExpandDims(x, 2)
// x: [B, nkvheads, 1, L, head_dim]
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
// x: [B, nkvheads, repeats, L, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
// gate_up_proj outputs [gate, up] concatenated
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
shape := gateUp.Shape()
halfDim := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
gate = mlx.SiLU(gate)
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}

View File

@@ -222,6 +222,14 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
mlx.Keep(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
@@ -264,10 +272,19 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// CFG Batching: single forward pass with batch=2
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
L := batchedOutput.Shape()[1]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
@@ -305,6 +322,9 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {

View File

@@ -241,6 +241,14 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(posEmb, negEmb)
}
// Pre-compute batched embeddings for CFG (single forward pass optimization)
var batchedEmb *mlx.Array
if useCFG {
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// Encode all input images to latents and concatenate
fmt.Println("Encoding images to latents...")
allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
@@ -291,11 +299,18 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(latentInput, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// CFG Batching: single forward pass with batch=2
// Tile inputs: [1, L, D] -> [2, L, D]
batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
batchedTimestep := mlx.Tile(timestep, []int32{2})
posOutput = mlx.Slice(posOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, posOutput.Shape()[2]})
negOutput = mlx.Slice(negOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, negOutput.Shape()[2]})
// Single batched forward pass
batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
D := batchedOutput.Shape()[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
} else {
@@ -317,6 +332,9 @@ func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array,
if negEmb != nil {
negEmb.Free()
}
if batchedEmb != nil {
batchedEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
imageLatentsPacked.Free()

View File

@@ -28,12 +28,12 @@ type Qwen3Config struct {
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OProj *nn.Linear `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
@@ -136,9 +136,9 @@ func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj *nn.Linear `weight:"gate_proj"`
UpProj *nn.Linear `weight:"up_proj"`
DownProj *nn.Linear `weight:"down_proj"`
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP

View File

@@ -36,8 +36,8 @@ type TransformerConfig struct {
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 *nn.Linear `weight:"mlp.0"`
Linear2 *nn.Linear `weight:"mlp.2"`
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
@@ -74,7 +74,7 @@ func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear *nn.Linear `weight:"2-1"`
Linear nn.LinearLayer `weight:"2-1"`
}
// Forward embeds patchified image latents
@@ -86,7 +86,7 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear *nn.Linear `weight:"1"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
@@ -100,12 +100,13 @@ func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 *nn.Linear `weight:"w1"` // gate projection
W2 *nn.Linear `weight:"w2"` // down projection
W3 *nn.Linear `weight:"w3"` // up projection
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -115,6 +116,7 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
@@ -126,17 +128,69 @@ func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ *nn.Linear `weight:"to_q"`
ToK *nn.Linear `weight:"to_k"`
ToV *nn.Linear `weight:"to_v"`
ToOut *nn.Linear `weight:"to_out.0"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Computed fields
NHeads int32
HeadDim int32
Dim int32
Scale float32
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
}
// Forward computes attention
@@ -146,11 +200,24 @@ func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
L := shape[1]
D := shape[2]
// Project Q, K, V
xFlat := mlx.Reshape(x, B*L, D)
q := attn.ToQ.Forward(xFlat)
k := attn.ToK.Forward(xFlat)
v := attn.ToV.Forward(xFlat)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
@@ -227,7 +294,7 @@ type TransformerBlock struct {
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN *nn.Linear `weight:"adaLN_modulation.0,optional"` // only if modulation
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -281,8 +348,8 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN *nn.Linear `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output *nn.Linear `weight:"linear"` // [dim] -> [out_channels]
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
@@ -350,12 +417,11 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
// Load weights from tensor blobs with BF16 conversion
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
@@ -377,7 +443,7 @@ func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
@@ -391,6 +457,20 @@ func (m *Transformer) initComputedFields() {
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
@@ -404,7 +484,7 @@ func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.Weight.Shape()[0]
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
@@ -423,6 +503,8 @@ type RoPECache struct {
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
@@ -456,6 +538,8 @@ func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}

View File

@@ -104,6 +104,8 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
@@ -205,6 +207,10 @@ func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Arra
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
@@ -734,18 +740,26 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
h := vae.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = vae.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range vae.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev := h
prev = h
h = vae.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = vae.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
@@ -754,7 +768,6 @@ func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
prev.Free()
mlx.Eval(h)
return h

View File

@@ -26,10 +26,12 @@ type GenerateConfig struct {
Progress ProgressFunc // Optional progress callback
CapturePath string // GPU capture path (debug)
// Layer caching options (speedup via shallow layer reuse)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 15)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
}
// ProgressFunc is called during generation with step progress.
@@ -42,6 +44,7 @@ type Model struct {
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
@@ -196,13 +199,17 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.LayerCache {
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 15 // Half of 30 layers
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
}
useCFG := cfg.NegativePrompt != ""
@@ -260,12 +267,54 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
mlx.Eval(ropeCache.UnifiedCos)
}
// Step cache for shallow layer reuse (DeepCache/Learning-to-Cache style)
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching enabled: %d layers, refresh every %d steps\n",
cfg.CacheLayers, cfg.CacheInterval)
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
}
// Denoising loop
@@ -277,6 +326,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
@@ -289,50 +339,77 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
}
tCurr := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
var noisePred *mlx.Array
patches := PatchifyLatents(latents, tcfg.PatchSize)
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
var output *mlx.Array
if stepCache != nil {
// Use layer caching for faster inference
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
posOutput := m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
// Note: CFG with layer cache shares the cache between pos/neg
// This is approximate but fast - neg prompt uses same cached shallow layers
negOutput := m.Transformer.ForwardWithCache(patches, timestep, negEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
diff := mlx.Sub(posOutput, negOutput)
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
} else {
output = m.Transformer.ForwardWithCache(patches, timestep, posEmb, ropeCache,
stepCache, i, cfg.CacheInterval)
}
} else {
// Standard forward without caching
if useCFG {
posOutput := m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
negOutput := m.Transformer.Forward(patches, timestep, negEmb, ropeCache)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
output = mlx.Add(negOutput, scaledDiff)
noisePred = mlx.Add(negPred, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
}
}
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
noisePred := UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
}
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep latents and any cached arrays
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
@@ -361,8 +438,14 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if stepCache != nil {
stepCache.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
}
// VAE decode

View File

@@ -10,6 +10,13 @@ type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
// This allows swapping between Linear and QuantizedLinear at runtime.
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32 // Returns the output dimension of the layer
}
// Linear applies an affine transformation: y = x @ W.T + b
// Weight is stored as [out_features, in_features], matching PyTorch/MLX convention.
type Linear struct {
@@ -49,6 +56,11 @@ func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
return mlx.Linear(x, w)
}
// OutputDim returns the output dimension of the linear layer.
func (l *Linear) OutputDim() int32 {
return l.Weight.Shape()[0]
}
// ToQuantized converts this Linear to a QuantizedLinear.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
@@ -84,6 +96,13 @@ func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
return out
}
// OutputDim returns the output dimension of the quantized linear layer.
// For mxfp8/mxfp4, quantized weight shape is [out_features, in_features / group_size].
// The output dimension is the first dimension of the weight.
func (ql *QuantizedLinear) OutputDim() int32 {
return ql.Weight.Shape()[0]
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array `weight:"weight"`

22
x/imagegen/quantize.go Normal file
View File

@@ -0,0 +1,22 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}

View File

@@ -13,16 +13,21 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
@@ -34,15 +39,17 @@ type Request struct {
// Response is streamed back for each progress update
type Response struct {
Content string `json:"content"`
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
model ImageModel
modelName string
modelType string // "zimage" or "glm_image"
}
// Execute is the entry point for the image runner subprocess
@@ -72,15 +79,35 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
// Detect model type and load appropriate model
modelType, err := detectModelType(*modelName)
if err != nil {
return fmt.Errorf("failed to detect model type: %w", err)
}
var model ImageModel
switch modelType {
case "GlmImagePipeline":
slog.Info("loading GLM-Image model")
m := &glm_image.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load GLM-Image model: %w", err)
}
model = m
default:
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
slog.Info("loading Z-Image model")
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load Z-Image model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
modelType: modelType,
}
// Set up HTTP handlers
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
// Default steps depend on model type
switch s.modelType {
case "GlmImagePipeline":
req.Steps = 50 // GLM-Image default
default:
req.Steps = 9 // Z-Image turbo default
}
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image
// Generate image using interface method
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
if err != nil {
// Don't send error for cancellation
@@ -191,10 +208,10 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Save image
outPath := filepath.Join(os.TempDir(), fmt.Sprintf("ollama-image-%d.png", time.Now().UnixNano()))
if err := imagegen.SaveImage(img, outPath); err != nil {
resp := Response{Content: fmt.Sprintf("error saving: %v", err), Done: true}
// Encode image as base64 PNG
imageData, err := imagegen.EncodeImageBase64(img)
if err != nil {
resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
@@ -204,14 +221,47 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
// Free the generated image array and clean up MLX state
img.Free()
mlx.ClearCache()
mlx.MetalResetPeakMemory()
// Send final response
// Send final response with image data
resp := Response{
Content: fmt.Sprintf("\n\nImage saved to: %s\n", outPath),
Done: true,
Image: imageData,
Done: true,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
}
// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", err
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return "ZImagePipeline", nil // Default to Z-Image
}
// Try both _class_name (diffusers format) and architecture (ollama format)
var index struct {
ClassName string `json:"_class_name"`
Architecture string `json:"architecture"`
}
if err := json.Unmarshal(data, &index); err != nil {
return "ZImagePipeline", nil
}
// Prefer _class_name, fall back to architecture
className := index.ClassName
if className == "" {
className = index.Architecture
}
if className == "" {
return "ZImagePipeline", nil
}
return className, nil
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// WeightSource is an interface for loading weights.
@@ -102,6 +103,22 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
}
// Handle nn.LinearLayer interface fields specially
if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
if !hasTag {
continue // no tag = skip
}
layer, err := LoadLinearLayer(weights, fullPath)
if err != nil {
if !optional {
*errs = append(*errs, fullPath+": "+err.Error())
}
continue
}
fieldVal.Set(reflect.ValueOf(layer))
continue
}
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -176,3 +193,64 @@ func joinPath(prefix, suffix string) string {
}
return prefix + "." + suffix
}
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
if weights.HasTensor(scalePath) {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
}
scales, err := weights.GetTensor(scalePath)
if err != nil {
return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
var qbiases *mlx.Array
qbiasPath := path + ".weight_qbias"
if weights.HasTensor(qbiasPath) {
qbiases, _ = weights.GetTensor(qbiasPath)
}
if mlx.MetalIsAvailable() {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: 32,
Bits: 8,
Mode: "affine",
}, nil
}
dequantized := mlx.Dequantize(weight, scales, qbiases, 32, 8, "affine")
return nn.NewLinear(dequantized, bias), nil
}
// Load as regular Linear
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
}
// Bias is optional
var bias *mlx.Array
biasPath := path + ".bias"
if weights.HasTensor(biasPath) {
bias, _ = weights.GetTensor(biasPath)
}
return nn.NewLinear(weight, bias), nil
}

View File

@@ -14,7 +14,9 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
@@ -46,7 +48,8 @@ type completionRequest struct {
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content"`
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}
@@ -69,7 +72,7 @@ func NewServer(modelName string) (*Server, error) {
port = rand.Intn(65535-49152) + 49152
}
// Get the ollama executable path
// Get the ollama-mlx executable path (in same directory as current executable)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
@@ -77,11 +80,42 @@ func NewServer(modelName string) (*Server, error) {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
s := &Server{
cmd: cmd,
port: port,
@@ -112,7 +146,7 @@ func NewServer(modelName string) (*Server, error) {
}
}()
slog.Info("starting image runner subprocess", "model", modelName, "port", port)
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
@@ -250,15 +284,23 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
}
// Stream responses
// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
continue
}
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
}
fn(llm.CompletionResponse{
Content: cresp.Content,
Content: content,
Done: cresp.Done,
})
if cresp.Done {

View File

@@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error {
return nil
}
// Filter existing
var blobs []Blob
// Calculate total from all blobs (for accurate progress reporting on resume)
var total int64
for _, b := range opts.Blobs {
total += b.Size
}
// Filter out already-downloaded blobs and track completed bytes
var blobs []Blob
var alreadyCompleted int64
for _, b := range opts.Blobs {
if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size {
if opts.Logger != nil {
opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size)
}
alreadyCompleted += b.Size
continue
}
blobs = append(blobs, b)
total += b.Size
}
if len(blobs) == 0 {
return nil
}
token := opts.Token
progress := newProgressTracker(total, opts.Progress)
progress.add(alreadyCompleted) // Report already-downloaded bytes upfront
d := &downloader{
client: cmp.Or(opts.Client, defaultClient),
baseURL: opts.BaseURL,
@@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error {
getToken: opts.GetToken,
userAgent: cmp.Or(opts.UserAgent, defaultUserAgent),
stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout),
progress: newProgressTracker(total, opts.Progress),
progress: progress,
speeds: &speedTracker{},
logger: opts.Logger,
}

View File

@@ -110,8 +110,6 @@ var defaultClient = &http.Client{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
},
Timeout: 5 * time.Minute,
// Don't follow redirects automatically - we handle them manually
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},

View File

@@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) {
}
}
func TestDownloadResumeProgressTotal(t *testing.T) {
// Test that when resuming a download with some blobs already present:
// 1. Total reflects ALL blob sizes (not just remaining)
// 2. Completed starts at the size of already-downloaded blobs
serverDir := t.TempDir()
blob1, data1 := createTestBlob(t, serverDir, 1000)
blob2, data2 := createTestBlob(t, serverDir, 2000)
blob3, data3 := createTestBlob(t, serverDir, 3000)
// Pre-populate client with blob1 and blob2 (simulating partial download)
clientDir := t.TempDir()
for _, b := range []struct {
blob Blob
data []byte
}{{blob1, data1}, {blob2, data2}} {
path := filepath.Join(clientDir, digestToPath(b.blob.Digest))
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, b.data, 0o644); err != nil {
t.Fatal(err)
}
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
digest := filepath.Base(r.URL.Path)
path := filepath.Join(serverDir, digestToPath(digest))
data, err := os.ReadFile(path)
if err != nil {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
w.WriteHeader(http.StatusOK)
w.Write(data)
}))
defer server.Close()
var firstCompleted, firstTotal int64
var gotFirstProgress bool
var mu sync.Mutex
err := Download(context.Background(), DownloadOptions{
Blobs: []Blob{blob1, blob2, blob3},
BaseURL: server.URL,
DestDir: clientDir,
Concurrency: 1,
Progress: func(completed, total int64) {
mu.Lock()
defer mu.Unlock()
if !gotFirstProgress {
firstCompleted = completed
firstTotal = total
gotFirstProgress = true
}
},
})
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Total should be sum of ALL blobs, not just blob3
expectedTotal := blob1.Size + blob2.Size + blob3.Size
if firstTotal != expectedTotal {
t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal)
}
// First progress call should show already-completed bytes from blob1+blob2
expectedCompleted := blob1.Size + blob2.Size
if firstCompleted < expectedCompleted {
t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted)
}
// Verify blob3 was downloaded
verifyBlob(t, clientDir, blob3, data3)
}
func TestDownloadDigestMismatch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return wrong data

View File

@@ -54,6 +54,16 @@ func (r *Registry) RegisterBash() {
r.Register(&BashTool{})
}
// RegisterWebSearch adds the web search tool to the registry.
func (r *Registry) RegisterWebSearch() {
r.Register(&WebSearchTool{})
}
// RegisterWebFetch adds the web fetch tool to the registry.
func (r *Registry) RegisterWebFetch() {
r.Register(&WebFetchTool{})
}
// Get retrieves a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
tool, ok := r.tools[name]

162
x/tools/webfetch.go Normal file
View File

@@ -0,0 +1,162 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
)
const (
webFetchAPI = "https://ollama.com/api/web_fetch"
webFetchTimeout = 30 * time.Second
)
// ErrWebFetchAuthRequired is returned when web fetch requires authentication
var ErrWebFetchAuthRequired = errors.New("web fetch requires authentication")
// WebFetchTool implements web page fetching using Ollama's hosted API.
type WebFetchTool struct{}
// Name returns the tool name.
func (w *WebFetchTool) Name() string {
return "web_fetch"
}
// Description returns a description of the tool.
func (w *WebFetchTool) Description() string {
return "Fetch and extract text content from a web page. Use this to read the full content of a URL found in search results or provided by the user."
}
// Schema returns the tool's parameter schema.
func (w *WebFetchTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
props.Set("url", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The URL to fetch and extract content from",
})
return api.ToolFunction{
Name: w.Name(),
Description: w.Description(),
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"url"},
},
}
}
// webFetchRequest is the request body for the web fetch API.
type webFetchRequest struct {
URL string `json:"url"`
}
// webFetchResponse is the response from the web fetch API.
type webFetchResponse struct {
Title string `json:"title"`
Content string `json:"content"`
Links []string `json:"links,omitempty"`
}
// Execute fetches content from a web page.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
urlStr, ok := args["url"].(string)
if !ok || urlStr == "" {
return "", fmt.Errorf("url parameter is required")
}
// Validate URL
if _, err := url.Parse(urlStr); err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
// Prepare request
reqBody := webFetchRequest{
URL: urlStr,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshaling request: %w", err)
}
// Parse URL and add timestamp for signing
fetchURL, err := url.Parse(webFetchAPI)
if err != nil {
return "", fmt.Errorf("parsing fetch URL: %w", err)
}
q := fetchURL.Query()
q.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
fetchURL.RawQuery = q.Encode()
// Sign the request using Ollama key (~/.ollama/id_ed25519)
ctx := context.Background()
data := fmt.Appendf(nil, "%s,%s", http.MethodPost, fetchURL.RequestURI())
signature, err := auth.Sign(ctx, data)
if err != nil {
return "", fmt.Errorf("signing request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fetchURL.String(), bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if signature != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
}
// Send request
client := &http.Client{Timeout: webFetchTimeout}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return "", ErrWebFetchAuthRequired
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("web fetch API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var fetchResp webFetchResponse
if err := json.Unmarshal(body, &fetchResp); err != nil {
return "", fmt.Errorf("parsing response: %w", err)
}
// Format result
var sb strings.Builder
if fetchResp.Title != "" {
sb.WriteString(fmt.Sprintf("Title: %s\n\n", fetchResp.Title))
}
if fetchResp.Content != "" {
sb.WriteString("Content:\n")
sb.WriteString(fetchResp.Content)
} else {
sb.WriteString("No content could be extracted from the page.")
}
return sb.String(), nil
}