Compare commits

...

40 Commits

Author SHA1 Message Date
Eva Ho
9d49839a2f fix test 2026-01-19 12:54:45 -05:00
Eva Ho
b519c636ff updater: add more test converage to cover auto updater 2026-01-19 12:54:45 -05:00
Eva Ho
5bca8928b3 fix tests 2026-01-19 12:54:45 -05:00
Eva Ho
bdcda0c243 fix tests 2026-01-19 12:54:45 -05:00
Eva Ho
34a51bdf26 fix test 2026-01-19 12:54:45 -05:00
Eva Ho
d0c3afb301 address comments 2026-01-19 12:54:45 -05:00
Eva Ho
83b0d76de7 address comment 2026-01-19 12:54:45 -05:00
Eva Ho
9fc0500518 address comment 2026-01-19 12:54:45 -05:00
Eva Ho
08daf70ceb fix: gofmt formatting in updater_test.go 2026-01-19 12:54:45 -05:00
Eva Ho
0e595cb5ea fix test 2026-01-19 12:54:45 -05:00
Eva Ho
cf0e8e64b5 fix format 2026-01-19 12:54:45 -05:00
Eva Ho
0e35204b32 fix test 2026-01-19 12:54:45 -05:00
Eva Ho
2f352bbf14 fix test 2026-01-19 12:54:45 -05:00
Eva Ho
1ada0cfbb7 clean up 2026-01-19 12:54:45 -05:00
Eva Ho
c5b7e8f343 fix behaviour when switching between enabled and disabled 2026-01-19 12:54:45 -05:00
Eva Ho
9738b25a7b fix test 2026-01-19 12:54:45 -05:00
Eva Ho
b7915dd601 app: add upgrade configuration to settings page 2026-01-19 12:54:45 -05:00
Jeffrey Morgan
03bf241c33 x/imagegen: add FP4 quantization support for image generation models (#13773)
Add --quantize fp4 support to ollama create for image generation models
(flux2, z-image-turbo), using MLX's affine 4-bit quantization.

Changes:
- Add fp4 to validation in CreateImageGenModel
- Add FP4 case to quantizeTensor (group_size=32, bits=4, affine mode)
- Add GetQuantization() to WeightSource interface for dynamic params
- Update LoadLinearLayer to use quantization params from model metadata
2026-01-19 00:54:54 -08:00
Jeffrey Morgan
a887406c24 x/imagegen: add preliminary support for FLUX.2-klein model (#13772) 2026-01-18 22:30:49 -08:00
Jeffrey Morgan
d51e95ba7e server: prevent image generation models from reloading on every request (#13771)
The loadImageGen function was not setting Options on the runnerRef,
causing needsReload() to always return true (since it checks if
runner.Options == nil). This resulted in the image generation
subprocess being killed and restarted for every request.
2026-01-18 20:50:04 -08:00
Jeffrey Morgan
3d01f2aa34 parsers: refactor Nemotron parser to reuse Qwen3Coder for tool calls (#13764)
Simplify Nemotron3NanoParser by delegating tool call parsing to
Qwen3CoderParser instead of duplicating the parsing logic. The
Nemotron parser now only handles the thinking state machine and
transitions to Qwen3CoderParser for content and tool call parsing.

This also fixes an issue where tool calls without </think> would
cause the parser to get stuck in thinking mode.
2026-01-17 18:28:52 -08:00
Jeffrey Morgan
634c416645 Add experimental image generation fields to /api/generate (#13753)
Request fields (experimental):
- width: image width (max 4096)
- height: image height (max 4096)
- steps: denoising steps
- seed: random seed

Response fields (experimental):
- images: base64-encoded generated images
- completed: current step progress
- total: total steps

Other changes:
- Fix lifecycle bug where image models wouldn't unload (refCount issue)
- Fix "headers already written" error on Ctrl+C during streaming
- Add gin middleware for OpenAI /v1/images/generations compatibility
- Update CLI to use /api/generate with progress bar
- Add preload support in interactive mode
2026-01-17 18:27:41 -08:00
Michael
57de86cc61 docs: update claude code docs (#13757)
* docs: update claude code docs
2026-01-16 22:41:34 -08:00
Daniel Hiltgen
12719b6e87 MLX - dynamic loading of mlx-c (#13735)
* MLX - dynamic loading of mlx-c

Create a wrapper layer to indirect the dependency on mlx-c so
the main ollama binary does not have a load-time dependency on mlx-c, mlx, and on linux, cuda.  Lazy load the library via dlopen
so we can adjust the path to ensure the dependencies are found
and fail gracefully if not present.

* review comments

* fix broken tests
2026-01-16 16:34:22 -08:00
Patrick Devine
a077d996e3 Fix create and show commands for experimental models (#13741)
* x: make `ollama create --experimental` import from safetensors

This change allows pulling in safetensors models into the new experimental model format, and also
fixes the `ollama show` command to be able to correctly display the model information.

* gofumpt the linter

* gofumpt the linter again

* validate the model name
2026-01-16 14:31:55 -08:00
Jeffrey Morgan
c23d5095de x/imagegen: clean up image generation code (#13725) 2026-01-16 12:19:25 -08:00
Bruce MacDonald
7601f0e93e server: reject unexpected auth hosts (#13738)
Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow.

Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com>

* gofmt

---------

Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
2026-01-16 14:10:36 -05:00
Eva H
aad3f03890 app: allow macOS app to terminate during system shutdown (#13737) 2026-01-16 09:05:04 -05:00
Gyungrai Wang
55d0b6e8b9 integration: fix tools_test.go for ToolCallFunctionArguments API change (#13731) 2026-01-15 16:08:09 -08:00
Devon Rifkin
38eac40d56 openai: tweak v1/responses to conform better (#13736)
* openai: tweak v1/responses to conform better

* openai: provide better error for image URLs

* lint
2026-01-15 15:46:36 -08:00
Jeffrey Morgan
80f3f1bc25 readme: add instructions to build with MLX (#13733) 2026-01-15 11:03:52 -08:00
Parth Sareen
b1a0db547b docs: add env var needed for claude code in docs (#13721) 2026-01-15 10:11:00 -08:00
Parth Sareen
75d7b5f926 cmd: enable multi-line input and shift enter (#13694) 2026-01-14 17:52:46 -08:00
vincent d warmerdam
349d814814 docs: add marimo integration (#13326)
* docs added

* fix title

* add marimo to docs.json

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:37:38 -08:00
Yuhong Sun
c8743031e0 docs: add onyx integration (#13135)
* Ready for team review

* Update docs/integrations/onyx.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* update docs.json

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:32:05 -08:00
Jeffrey Morgan
4adb9cf4bb scripts: fix macOS auto-update signature verification failure (#13713)
Add --norsrc flag to ditto commands when creating Ollama-darwin.zip
to exclude AppleDouble resource fork files (._* files) from the archive.

The mlx.metallib file has extended attributes, which causes ditto to
include a ._mlx.metallib AppleDouble file in the zip. Since this file
is not part of the code signature seal, macOS rejects the bundle during
auto-update verification with:

  "a sealed resource is missing or invalid"
  "file added: .../._mlx.metallib"

The --norsrc flag prevents ditto from preserving resource forks and
extended attributes, ensuring only signed files are included in the
release archive.
2026-01-14 07:48:10 -08:00
Daniel Hiltgen
74f475e735 Revert "Documentation edits made through Mintlify web editor" (#13688)
This reverts commit c6d4c0c7f2.

Merge after 0.14.0 ships for the updated Linux documentation.
2026-01-14 07:42:34 -08:00
Maternion
875cecba74 docs: update default context window size to 4096 tokens (#13709) 2026-01-14 01:01:28 -08:00
Josh Daniel Bañares
7d411a4686 docs: update web search param in examples (#13711) 2026-01-14 00:38:39 -08:00
Daniel Hiltgen
02a2401596 mlx: bundle openblas dependency (#13706) 2026-01-13 15:29:47 -08:00
116 changed files with 17374 additions and 2180 deletions

View File

@@ -190,7 +190,7 @@ if(MLX_ENGINE)
install(TARGETS mlx mlxc install(TARGETS mlx mlxc
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX

View File

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache # install epel-release for ccache
RUN yum install -y yum-utils epel-release \ RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache \ && dnf install -y clang ccache git \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++ ENV CC=clang CXX=clang++
@@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum . COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
@@ -156,14 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \ && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL} && cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama
@@ -172,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download RUN go mod download
COPY . . COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'" ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1 ENV CGO_ENABLED=1
ARG CGO_CFLAGS ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CXXFLAGS ARG CGO_CXXFLAGS
RUN --mount=type=cache,target=/root/.cache/go-build \ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama . go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -185,7 +180,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/ COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/ COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

1
MLX_VERSION Normal file
View File

@@ -0,0 +1 @@
v0.4.1

View File

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library ## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library') Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Here are some example models that can be downloaded: Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` | | Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` | | LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` | | Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE] > [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,6 +260,38 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2 ./ollama run llama3.2
``` ```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
```shell
go build -tags mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API ## REST API
Ollama has a REST API for running and managing models. Ollama has a REST API for running and managing models.
@@ -290,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop ### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui) - [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat) - [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -493,7 +526,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database ### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector) - [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md) - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps) - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama) - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases) - [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -636,6 +669,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability ### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama. - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing. - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics. - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -644,4 +678,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications. - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security ### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server) - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -127,6 +127,20 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true. // each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob). // Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"` TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
} }
// ChatRequest describes a request sent by [Client.Chat]. // ChatRequest describes a request sent by [Client.Chat].
@@ -860,6 +874,20 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens, // Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter. // if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Image contains a base64-encoded generated image.
// Only present for image generation models.
Image string `json:"image,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
} }
// ModelDetails provides details about a model. // ModelDetails provides details about a model.

View File

@@ -253,6 +253,8 @@ func main() {
done <- osrv.Run(octx) done <- osrv.Run(octx)
}() }()
upd := &updater.Updater{Store: st}
uiServer := ui.Server{ uiServer := ui.Server{
Token: token, Token: token,
Restart: func() { Restart: func() {
@@ -267,6 +269,10 @@ func main() {
ToolRegistry: toolRegistry, ToolRegistry: toolRegistry,
Dev: devMode, Dev: devMode,
Logger: slog.Default(), Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
} }
srv := &http.Server{ srv := &http.Server{
@@ -284,8 +290,13 @@ func main() {
slog.Debug("background desktop server done") slog.Debug("background desktop server done")
}() }()
updater := &updater.Updater{Store: st} upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun() hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil { if err != nil {
@@ -348,6 +359,18 @@ func startHiddenTasks() {
// CLI triggered app startup use-case // CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup") slog.Info("deferring pending update for fast startup")
} else { } else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil { if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err) slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization // Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate> @interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem; @property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable; @property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end @end
@implementation AppDelegate @implementation AppDelegate
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
} }
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification { - (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon // if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath]; NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) { if (![bundlePath hasSuffix:@".app"]) {
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES]; [NSApp activateIgnoringOtherApps:YES];
} }
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender { - (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil]; [NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory]; [NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel; return NSTerminateCancel;

View File

@@ -9,12 +9,12 @@ import (
"strings" "strings"
"time" "time"
sqlite3 "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
// currentSchemaVersion defines the current database schema version. // currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations. // Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12 const currentSchemaVersion = 13
// database wraps the SQLite connection. // database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access: // SQLite handles its own locking for concurrent access:
@@ -85,6 +85,7 @@ func (db *database) init() error {
think_enabled BOOLEAN NOT NULL DEFAULT 0, think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '', think_level TEXT NOT NULL DEFAULT '',
remote TEXT NOT NULL DEFAULT '', -- deprecated remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d schema_version INTEGER NOT NULL DEFAULT %d
); );
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err) return fmt.Errorf("migrate v11 to v12: %w", err)
} }
version = 12 version = 12
case 12:
// add auto_update_enabled column to settings table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default: default:
// If we have a version we don't recognize, just set it to current // If we have a version we don't recognize, just set it to current
// This might happen during development // This might happen during development
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
return nil return nil
} }
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
func (db *database) migrateV12ToV13() error {
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add auto_update_enabled column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error { func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
@@ -482,19 +504,11 @@ func (db *database) cleanupOrphanedData() error {
} }
func duplicateColumnError(err error) bool { func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok { return err != nil && strings.Contains(err.Error(), "duplicate column name")
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name")
}
return false
} }
func columnNotExists(err error) bool { func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok { return err != nil && strings.Contains(err.Error(), "no such column")
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column")
}
return false
} }
func (db *database) getAllChats() ([]Chat, error) { func (db *database) getAllChats() ([]Chat, error) {
@@ -1108,9 +1122,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings var s Settings
err := db.conn.QueryRow(` err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel) `).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil { if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err) return Settings{}, fmt.Errorf("get settings: %w", err)
} }
@@ -1121,8 +1135,8 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error { func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(` _, err := db.conn.Exec(`
UPDATE settings UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ? SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel) `, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil { if err != nil {
return fmt.Errorf("set settings: %w", err) return fmt.Errorf("set settings: %w", err)
} }

View File

@@ -169,6 +169,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open // SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
} }
type Store struct { type Store struct {

View File

@@ -413,6 +413,7 @@ export class Settings {
ThinkLevel: string; ThinkLevel: string;
SelectedModel: string; SelectedModel: string;
SidebarOpen: boolean; SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) { constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source); if ('string' === typeof source) source = JSON.parse(source);
@@ -431,6 +432,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"]; this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"]; this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"]; this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
} }
} }
export class SettingsResponse { export class SettingsResponse {
@@ -467,6 +469,46 @@ export class HealthResponse {
this.healthy = source["healthy"]; this.healthy = source["healthy"];
} }
} }
export class UpdateInfo {
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.currentVersion = source["currentVersion"];
this.availableVersion = source["availableVersion"];
this.updateAvailable = source["updateAvailable"];
this.updateDownloaded = source["updateDownloaded"];
}
}
export class UpdateCheckResponse {
updateInfo: UpdateInfo;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
}
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (Array.isArray(a)) {
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class User { export class User {
id: string; id: string;
email: string; email: string;

View File

@@ -414,3 +414,54 @@ export async function fetchHealth(): Promise<boolean> {
return false; return false;
} }
} }
export async function getCurrentVersion(): Promise<string> {
try {
const response = await fetch(`${API_BASE}/api/version`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const data = await response.json();
return data.version || "Unknown";
}
return "Unknown";
} catch (error) {
console.error("Error fetching version:", error);
return "Unknown";
}
}
export async function checkForUpdate(): Promise<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
}> {
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
throw new Error("Failed to check for update");
}
const data = await response.json();
return data.updateInfo;
}
export async function installUpdate(): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
const error = await response.text();
throw new Error(error || "Failed to install update");
}
}

View File

@@ -14,12 +14,13 @@ import {
XMarkIcon, XMarkIcon,
CogIcon, CogIcon,
ArrowLeftIcon, ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid"; } from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes"; import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router"; import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser"; import { useUser } from "@/hooks/useUser";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api"; import { getSettings, updateSettings, checkForUpdate } from "@/api";
function AnimatedDots() { function AnimatedDots() {
return ( return (
@@ -39,6 +40,12 @@ export default function Settings() {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const [showSaved, setShowSaved] = useState(false); const [showSaved, setShowSaved] = useState(false);
const [restartMessage, setRestartMessage] = useState(false); const [restartMessage, setRestartMessage] = useState(false);
const [updateInfo, setUpdateInfo] = useState<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
} | null>(null);
const { const {
user, user,
isAuthenticated, isAuthenticated,
@@ -76,6 +83,10 @@ export default function Settings() {
useEffect(() => { useEffect(() => {
refetchUser(); refetchUser();
// Check for updates on mount
checkForUpdate()
.then(setUpdateInfo)
.catch((err) => console.error("Error checking for update:", err));
}, []); // eslint-disable-line react-hooks/exhaustive-deps }, []); // eslint-disable-line react-hooks/exhaustive-deps
useEffect(() => { useEffect(() => {
@@ -344,6 +355,58 @@ export default function Settings() {
{/* Local Configuration */} {/* Local Configuration */}
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800"> <div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="space-y-4 p-4"> <div className="space-y-4 p-4">
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div className="flex-1">
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled ? (
<>
Automatically downloads updates when available.
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
Current version: {updateInfo?.currentVersion || "Loading..."}
</div>
</>
) : (
<>
Manually download updates.
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
<div className="space-y-2 text-sm">
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
</div>
{updateInfo?.availableVersion && (
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
</div>
)}
</div>
<a
href="https://ollama.com/download"
target="_blank"
rel="noopener noreferrer"
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
>
Download new version
</a>
</div>
</>
)}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */} {/* Expose Ollama */}
<Field> <Field>
<div className="flex items-start justify-between gap-4"> <div className="flex items-start justify-between gap-4">

View File

@@ -100,6 +100,17 @@ type HealthResponse struct {
Healthy bool `json:"healthy"` Healthy bool `json:"healthy"`
} }
type UpdateInfo struct {
CurrentVersion string `json:"currentVersion"`
AvailableVersion string `json:"availableVersion"`
UpdateAvailable bool `json:"updateAvailable"`
UpdateDownloaded bool `json:"updateDownloaded"`
}
type UpdateCheckResponse struct {
UpdateInfo UpdateInfo `json:"updateInfo"`
}
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/app/tools" "github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not" "github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses" "github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version" "github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth" ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@@ -106,6 +107,18 @@ type Server struct {
// Dev is true if the server is running in development mode // Dev is true if the server is running in development mode
Dev bool Dev bool
// Updater for checking and downloading updates
Updater UpdaterInterface
UpdateAvailableFunc func()
}
// UpdaterInterface defines the methods we need from the updater
type UpdaterInterface interface {
CheckForUpdate(ctx context.Context) (bool, string, error)
InstallAndRestart() error
CancelOngoingDownload()
TriggerImmediateCheck()
} }
func (s *Server) log() *slog.Logger { func (s *Server) log() *slog.Logger {
@@ -284,6 +297,8 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream)) mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings)) mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings)) mux.Handle("POST /api/v1/settings", handle(s.settings))
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
// Ollama proxy endpoints // Ollama proxy endpoints
ollamaProxy := s.ollamaProxy() ollamaProxy := s.ollamaProxy()
@@ -1448,6 +1463,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err) return fmt.Errorf("failed to save settings: %w", err)
} }
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength || if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models || old.Models != settings.Models ||
old.Expose != settings.Expose { old.Expose != settings.Expose {
@@ -1524,6 +1557,73 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response) return json.NewEncoder(w).Encode(response)
} }
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
currentVersion := version.Version
if s.Updater == nil {
return fmt.Errorf("updater not available")
}
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
if err != nil {
s.log().Warn("failed to check for update", "error", err)
// Don't return error, just log it and continue with no update available
}
response := responses.UpdateCheckResponse{
UpdateInfo: responses.UpdateInfo{
CurrentVersion: currentVersion,
AvailableVersion: updateVersion,
UpdateAvailable: updateAvailable,
UpdateDownloaded: updater.UpdateDownloaded,
},
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return fmt.Errorf("method not allowed")
}
if s.Updater == nil {
s.log().Error("install failed: updater not available")
return fmt.Errorf("updater not available")
}
// Check if update is downloaded
if !updater.UpdateDownloaded {
s.log().Error("install failed: no update downloaded")
return fmt.Errorf("no update downloaded")
}
// Send response before restarting
response := map[string]any{
"success": true,
"message": "Installing update and restarting...",
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
// Give the response time to be sent
time.Sleep(500 * time.Millisecond)
// Trigger the upgrade and restart
go func() {
time.Sleep(500 * time.Millisecond)
if err := s.Updater.InstallAndRestart(); err != nil {
s.log().Error("failed to install update", "error", err)
}
}()
return nil
}
func userAgent() string { func userAgent() string {
buildinfo, _ := debug.ReadBuildInfo() buildinfo, _ := debug.ReadBuildInfo()

View File

@@ -19,6 +19,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query() query := requestURL.Query()
query.Add("os", runtime.GOOS) query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH) query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version) currentVersion := version.Version
query.Add("version", currentVersion)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10)) query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID // The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
} }
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error { func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info // Do a head first to check etag info
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil) req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil { if err != nil {
return err return err
} }
// In case of slow downloads, continue the update check in the background // In case of slow downloads, continue the update check in the background
bgctx, cancel := context.WithCancel(ctx) bgctx, bgcancel := context.WithCancel(downloadCtx)
defer cancel() defer bgcancel()
go func() { go func() {
for { for {
select { select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename) _, err = os.Stat(stageFilename)
if err == nil { if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename) slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil return nil
} }
@@ -244,34 +259,95 @@ func cleanupOldDownloads(stageDir string) {
} }
type Updater struct { type Updater struct {
Store *store.Store Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload cancels any currently running download
func (u *Updater) CancelOngoingDownload() {
u.cancelDownloadLock.Lock()
defer u.cancelDownloadLock.Unlock()
if u.cancelDownload != nil {
slog.Info("cancelling ongoing update download")
u.cancelDownload()
u.cancelDownload = nil
}
}
// TriggerImmediateCheck signals the background checker to check for updates immediately
func (u *Updater) TriggerImmediateCheck() {
if u.checkNow != nil {
u.checkNow <- struct{}{}
}
} }
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) { func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
go func() { go func() {
// Don't blast an update message immediately after startup // Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay) time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval) slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for { for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
slog.Debug("stopping background update checker") slog.Debug("stopping background update checker")
return return
default: case <-u.checkNow:
time.Sleep(UpdateCheckInterval) // Immediate check triggered
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
} }
} }
}() }()
} }
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
available, resp := u.checkForUpdate(ctx)
return available, resp.UpdateVersion, nil
}
func (u *Updater) InstallAndRestart() error {
if !UpdateDownloaded {
return fmt.Errorf("no update downloaded")
}
slog.Info("installing update and restarting")
return DoUpgrade(true)
}

View File

@@ -11,6 +11,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -85,7 +86,17 @@ func TestBackgoundChecker(t *testing.T) {
UpdateCheckURLBase = server.URL + "/update.json" UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}} updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb) updater.StartBackgroundUpdaterChecker(ctx, cb)
select { select {
case <-stallTimer.C: case <-stallTimer.C:
@@ -99,3 +110,187 @@ func TestBackgoundChecker(t *testing.T) {
} }
} }
} }
func TestAutoUpdateDisabledSkipsDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
var downloadAttempted atomic.Bool
done := make(chan struct{})
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
downloadAttempted.Store(true)
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
// Ensure auto-update is disabled
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = false
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
cb := func(ver string) error {
t.Fatal("callback should not be called when auto-update is disabled")
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait enough time for multiple check cycles
time.Sleep(50 * time.Millisecond)
close(done)
if downloadAttempted.Load() {
t.Fatal("download should not be attempted when auto-update is disabled")
}
}
func TestCancelOngoingDownload(t *testing.T) {
UpdateStageDir = t.TempDir()
downloadStarted := make(chan struct{})
downloadCancelled := make(chan struct{})
ctx := t.Context()
VerifyDownload = func() error {
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
} else if r.URL.Path == "/9.9.9/"+Installer {
if r.Method == http.MethodHead {
w.Header().Set("Content-Length", "1000000")
w.WriteHeader(http.StatusOK)
return
}
// Signal that download has started
close(downloadStarted)
// Wait for cancellation or timeout
select {
case <-r.Context().Done():
close(downloadCancelled)
return
case <-time.After(5 * time.Second):
t.Error("download was not cancelled in time")
}
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
_, resp := updater.checkForUpdate(ctx)
// Start download in goroutine
go func() {
_ = updater.DownloadNewRelease(ctx, resp)
}()
// Wait for download to start
select {
case <-downloadStarted:
case <-time.After(2 * time.Second):
t.Fatal("download did not start in time")
}
// Cancel the download
updater.CancelOngoingDownload()
// Verify cancellation was received
select {
case <-downloadCancelled:
// Success
case <-time.After(2 * time.Second):
t.Fatal("download cancellation was not received by server")
}
}
func TestTriggerImmediateCheck(t *testing.T) {
UpdateStageDir = t.TempDir()
checkCount := atomic.Int32{}
checkDone := make(chan struct{}, 10)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Set a very long interval so only TriggerImmediateCheck causes checks
UpdateCheckInitialDelay = 1 * time.Millisecond
UpdateCheckInterval = 1 * time.Hour
VerifyDownload = func() error {
return nil
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
checkCount.Add(1)
select {
case checkDone <- struct{}{}:
default:
}
// Return no update available
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close()
cb := func(ver string) error {
return nil
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for goroutine to start and pass initial delay
time.Sleep(10 * time.Millisecond)
// With 1 hour interval, no check should have happened yet
initialCount := checkCount.Load()
// Trigger immediate check
updater.TriggerImmediateCheck()
// Wait for the triggered check
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("triggered check did not happen")
}
finalCount := checkCount.Load()
if finalCount <= initialCount {
t.Fatalf("TriggerImmediateCheck did not cause additional check: initial=%d, final=%d", initialCount, finalCount)
}
}

View File

@@ -369,25 +369,6 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil return nil
} }
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
// const ERROR_SUCCESS syscall.Errno = 0
// t.muMenus.RLock()
// menu := uintptr(t.menus[parentId])
// t.muMenus.RUnlock()
// res, _, err := pRemoveMenu.Call(
// menu,
// uintptr(menuItemId),
// MF_BYCOMMAND,
// )
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
// return err
// }
// t.delFromVisibleItems(parentId, menuItemId)
// return nil
// }
func (t *winTray) showMenu() error { func (t *winTray) showMenu() error {
p := point{} p := point{}
boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p))) boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p)))

View File

@@ -51,7 +51,6 @@ const (
IMAGE_ICON = 1 // Loads an icon IMAGE_ICON = 1 // Loads an icon
LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero
LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file
MF_BYCOMMAND = 0x00000000
MFS_DISABLED = 0x00000003 MFS_DISABLED = 0x00000003
MFT_SEPARATOR = 0x00000800 MFT_SEPARATOR = 0x00000800
MFT_STRING = 0x00000000 MFT_STRING = 0x00000000

View File

@@ -46,8 +46,9 @@ import (
"github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd" xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
) )
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -93,15 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
// Validate model name early to fail fast
modelName := args[0]
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader var reader io.Reader
filename, err := getModelfileName(cmd) filename, err := getModelfileName(cmd)
if os.IsNotExist(err) { if os.IsNotExist(err) {
if filename == "" { if filename == "" {
// No Modelfile found - check if current directory is an image gen model // No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") { if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p) return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
} }
reader = strings.NewReader("FROM .\n") reader = strings.NewReader("FROM .\n")
} else { } else {
@@ -134,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
spinner.Stop() spinner.Stop()
req.Model = args[0] req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" { if quantize != "" {
req.Quantize = quantize req.Quantize = quantize
@@ -527,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
// Check if this is an image generation model // Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) { if slices.Contains(info.Capabilities, model.CapabilityImage) {
if opts.Prompt == "" && !interactive { if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"") return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
} }
@@ -1742,15 +1815,22 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information") rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{ createCmd := &cobra.Command{
Use: "create MODEL", Use: "create MODEL",
Short: "Create a model", Short: "Create a model",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat, PreRunE: func(cmd *cobra.Command, args []string) error {
RunE: CreateHandler, // Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
} }
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")") createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{ showCmd := &cobra.Command{
Use: "show MODEL", Use: "show MODEL",
@@ -1905,6 +1985,7 @@ func NewCLI() *cobra.Command {
} { } {
switch cmd { switch cmd {
case runCmd: case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]}) appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd: case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{ appendEnvDocs(cmd, []envconfig.EnvVar{

View File

@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
ParameterSize: "10.3B", ParameterSize: "10.3B",
QuantizationLevel: "FP8", QuantizationLevel: "FP8",
}, },
Capabilities: []model.Capability{model.CapabilityImageGeneration}, Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0", Requires: "0.14.0",
}, false, &b) }, false, &b)
if err != nil { if err != nil {

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
Placeholder: "Send a message (/? for help)", Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`, AltPlaceholder: "Press Enter to send",
}) })
if err != nil { if err != nil {
return err return err

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings) - [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models) - [List Running Models](#list-running-models)
- [Version](#version) - [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
## Conventions ## Conventions
@@ -58,6 +59,15 @@ Advanced parameters (optional):
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
Experimental image generation parameters (for image generation models only):
> [!WARNING]
> These parameters are experimental and may change in future versions.
- `width`: width of the generated image in pixels
- `height`: height of the generated image in pixels
- `steps`: number of diffusion steps
#### Structured outputs #### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below. Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1877,55 @@ curl http://localhost:11434/api/version
"version": "0.5.1" "version": "0.5.1"
} }
``` ```
## Experimental Features
### Image Generation (Experimental)
> [!WARNING]
> Image generation is experimental and may change in future versions.
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
#### Example
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "x/z-image-turbo",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
Progress updates during generation:
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:00.000000Z",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "x/z-image-turbo",
"created_at": "2024-01-15T10:30:15.000000Z",
"image": "iVBORw0KGgoAAAANSUhEUg...",
"done": true,
"done_reason": "stop",
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables: To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell ```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434 export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored export ANTHROPIC_API_KEY=ollama # required but ignored
``` ```
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend: [Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell ```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
``` ```
Or set the environment variables in your shell profile: Or set the environment variables in your shell profile:
```shell ```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434 export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama export ANTHROPIC_API_KEY=ollama
``` ```

View File

@@ -275,6 +275,73 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `dimensions` - [x] `dimensions`
- [ ] `user` - [ ] `user`
### `/v1/images/generations` (experimental)
> Note: This endpoint is experimental and may change or be removed in future versions.
Generate images using image generation models.
<CodeGroup dropdown>
```python images.py
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
api_key='ollama', # required but ignored
)
response = client.images.generate(
model='x/z-image-turbo',
prompt='A cute robot learning to paint',
size='1024x1024',
response_format='b64_json',
)
print(response.data[0].b64_json[:50] + '...')
```
```javascript images.js
import OpenAI from "openai";
const openai = new OpenAI({
baseURL: "http://localhost:11434/v1/",
apiKey: "ollama", // required but ignored
});
const response = await openai.images.generate({
model: "x/z-image-turbo",
prompt: "A cute robot learning to paint",
size: "1024x1024",
response_format: "b64_json",
});
console.log(response.data[0].b64_json.slice(0, 50) + "...");
```
```shell images.sh
curl -X POST http://localhost:11434/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "x/z-image-turbo",
"prompt": "A cute robot learning to paint",
"size": "1024x1024",
"response_format": "b64_json"
}'
```
</CodeGroup>
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `size` (e.g. "1024x1024")
- [x] `response_format` (only `b64_json` supported)
- [ ] `n`
- [ ] `quality`
- [ ] `style`
- [ ] `user`
### `/v1/responses` ### `/v1/responses`
> Note: Added in Ollama v0.13.3 > Note: Added in Ollama v0.13.3

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama"; import { Ollama } from "ollama";
const client = new Ollama(); const client = new Ollama();
const results = await client.webSearch({ query: "what is ollama?" }); const results = await client.webSearch("what is ollama?");
console.log(JSON.stringify(results, null, 2)); console.log(JSON.stringify(results, null, 2));
``` ```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama"; import { Ollama } from "ollama";
const client = new Ollama(); const client = new Ollama();
const fetchResult = await client.webFetch({ url: "https://ollama.com" }); const fetchResult = await client.webFetch("https://ollama.com");
console.log(JSON.stringify(fetchResult, null, 2)); console.log(JSON.stringify(fetchResult, null, 2));
``` ```

View File

@@ -111,7 +111,9 @@
"/integrations/zed", "/integrations/zed",
"/integrations/roo-code", "/integrations/roo-code",
"/integrations/n8n", "/integrations/n8n",
"/integrations/xcode" "/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
] ]
}, },
{ {

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size? ## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens. By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use: This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

BIN
docs/images/marimo-chat.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 178 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

BIN
docs/images/onyx-login.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 300 KiB

BIN
docs/images/onyx-query.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 211 KiB

View File

@@ -2,6 +2,12 @@
title: Claude Code title: Claude Code
--- ---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
## Install ## Install
Install [Claude Code](https://code.claude.com/docs/en/overview): Install [Claude Code](https://code.claude.com/docs/en/overview):
@@ -25,22 +31,24 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables: 1. Set the environment variables:
```shell ```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434 export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
``` ```
2. Run Claude Code with an Ollama model: 2. Run Claude Code with an Ollama model:
```shell ```shell
claude --model qwen3-coder claude --model gpt-oss:20b
``` ```
Or run with environment variables inline: Or run with environment variables inline:
```shell ```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
``` ```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Connecting to ollama.com ## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com 1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
@@ -67,3 +75,4 @@ claude --model glm-4.7:cloud
### Local models ### Local models
- `qwen3-coder` - Excellent for coding tasks - `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model - `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -0,0 +1,73 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, go to the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base url to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -0,0 +1,63 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Login to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the set-up process select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,5 +1,5 @@
--- ---
title: "Linux" title: Linux
--- ---
## Install ## Install
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install ## Manual install
<Note> <Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note> </Note>
Download and extract the package: Download and extract the package:
```shell ```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \ curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar zx -C /usr | sudo tar x -C /usr
``` ```
Start Ollama: Start Ollama:
@@ -40,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package: If you have an AMD GPU, also download and extract the additional ROCm package:
```shell ```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar zx -C /usr | sudo tar x -C /usr
``` ```
### ARM64 install ### ARM64 install
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
Download and extract the ARM64-specific package: Download and extract the ARM64-specific package:
```shell ```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \ curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar zx -C /usr | sudo tar x -C /usr
``` ```
### Adding Ollama as a startup service (recommended) ### Adding Ollama as a startup service (recommended)
@@ -112,7 +113,11 @@ sudo systemctl status ollama
``` ```
<Note> <Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU. While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note> </Note>
## Customizing ## Customizing
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama: Or by re-downloading Ollama:
```shell ```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \ curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar zx -C /usr | sudo tar x -C /usr
``` ```
## Installing specific versions ## Installing specific versions

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather") t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
} }
if _, ok := lastToolCall.Function.Arguments["location"]; !ok { if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String()) t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
} }
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20) // TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int TopLogprobs int
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
} }
// DoneReason represents the reason why a completion response is done // DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested // Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
} }
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -8,6 +8,7 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -441,6 +442,7 @@ type ResponsesWriter struct {
stream bool stream bool
responseID string responseID string
itemID string itemID string
request openai.ResponsesRequest
} }
func (w *ResponsesWriter) writeEvent(eventType string, data any) error { func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response // Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse) response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response) return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
} }
@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{ w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer}, BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model), converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
model: req.Model, model: req.Model,
stream: streamRequested, stream: streamRequested,
responseID: responseID, responseID: responseID,
itemID: itemID, itemID: itemID,
request: req,
} }
// Set headers based on streaming mode // Set headers based on streaming mode
@@ -541,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
c.Next() c.Next()
} }
} }
type ImageWriter struct {
BaseWriter
}
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
if err := json.Unmarshal(data, &generateResponse); err != nil {
return 0, err
}
// Only write response when done with image
if generateResponse.Done && generateResponse.Image != "" {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
}
return len(data), nil
}
func (w *ImageWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ImageGenerationsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageGenerationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
} }
} }
} }
func TestImageGenerationsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "image generation basic",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
},
},
{
name: "image generation with size",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
Width: 512,
Height: 768,
},
},
{
name: "image generation missing prompt",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image generation missing model",
body: `{
"prompt": "a beautiful sunset"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}
func TestImageWriterResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
// Test that ImageWriter transforms GenerateResponse to OpenAI format
endpoint := func(c *gin.Context) {
resp := api.GenerateResponse{
Model: "test-model",
CreatedAt: time.Unix(1234567890, 0).UTC(),
Done: true,
Image: "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
}
data, _ := json.Marshal(resp)
c.Writer.Write(append(data, '\n'))
}
router := gin.New()
router.Use(ImageGenerationsMiddleware())
router.Handle(http.MethodPost, "/api/generate", endpoint)
body := `{"model": "test-model", "prompt": "test"}`
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
var imageResp openai.ImageGenerationResponse
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if imageResp.Created != 1234567890 {
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
}
if len(imageResp.Data) != 1 {
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
}
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}

View File

@@ -1,7 +1,6 @@
package parsers package parsers
import ( import (
"regexp"
"strings" "strings"
"unicode" "unicode"
@@ -14,243 +13,114 @@ const (
Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
Nemotron3NanoSkipWhitespaceAfterThinking Nemotron3NanoSkipWhitespaceAfterThinking
Nemotron3NanoCollectingContent Nemotron3NanoCollectingContent
Nemotron3NanoCollectingToolCalls
) )
const ( const (
nemotronThinkClose = "</think>" nemotronThinkClose = "</think>"
nemotronToolCallOpen = "<tool_call>" nemotronToolCallOpen = "<tool_call>"
nemotronToolCallClose = "</tool_call>"
) )
type Nemotron3NanoParser struct { type Nemotron3NanoParser struct {
state Nemotron3NanoParserState state Nemotron3NanoParserState
buffer strings.Builder buffer strings.Builder
tools []api.Tool toolParser *Qwen3CoderParser
} }
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true } func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true } func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.toolParser = &Qwen3CoderParser{}
p.toolParser.Init(tools, nil, nil)
// thinking is enabled if user requests it
thinkingEnabled := thinkValue != nil && thinkValue.Bool() thinkingEnabled := thinkValue != nil && thinkValue.Bool()
prefill := lastMessage != nil && lastMessage.Role == "assistant" prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !thinkingEnabled { if !thinkingEnabled || (prefill && lastMessage.Content != "") {
p.state = Nemotron3NanoCollectingContent p.state = Nemotron3NanoCollectingContent
return tools } else {
p.state = Nemotron3NanoCollectingThinking
} }
if prefill && lastMessage.Content != "" {
p.state = Nemotron3NanoCollectingContent
return tools
}
p.state = Nemotron3NanoCollectingThinking
return tools return tools
} }
type nemotronEvent interface {
isNemotronEvent()
}
type nemotronEventThinkingContent struct {
content string
}
type nemotronEventContent struct {
content string
}
type nemotronEventToolCall struct {
toolCall api.ToolCall
}
func (nemotronEventThinkingContent) isNemotronEvent() {}
func (nemotronEventContent) isNemotronEvent() {}
func (nemotronEventToolCall) isNemotronEvent() {}
func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s) if p.state == Nemotron3NanoCollectingContent {
events := p.parseEvents() return p.toolParser.Add(s, done)
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case nemotronEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case nemotronEventThinkingContent:
thinkingSb.WriteString(event.content)
case nemotronEventContent:
contentSb.WriteString(event.content)
}
} }
return contentSb.String(), thinkingSb.String(), toolCalls, nil if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
} s = strings.TrimLeftFunc(s, unicode.IsSpace)
if s == "" {
func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent { return "", "", nil, nil
var all []nemotronEvent
keepLooping := true
for keepLooping {
var events []nemotronEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
}
wsLen := trailingWhitespaceLen(bufStr)
return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
}
func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
bufStr := p.buffer.String()
if bufStr == "" {
return nil, false
}
switch p.state {
case Nemotron3NanoCollectingThinking:
if strings.Contains(bufStr, nemotronThinkClose) {
split := strings.SplitN(bufStr, nemotronThinkClose, 2)
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
p.buffer.Reset()
remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.WriteString(remainder)
// Transition to whitespace-skipping state if buffer is empty,
// otherwise go directly to content collection
if remainder == "" {
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
} else {
p.state = Nemotron3NanoCollectingContent
}
if thinking != "" {
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
}
return nil, true
}
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
p.buffer.Reset()
p.buffer.WriteString(ambig)
if unambig != "" {
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
}
return nil, false
// We only want to skip whitespace between thinking and content
case Nemotron3NanoSkipWhitespaceAfterThinking:
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(bufStr)
if bufStr == "" {
return nil, false
} }
p.state = Nemotron3NanoCollectingContent p.state = Nemotron3NanoCollectingContent
return nil, true return p.toolParser.Add(s, done)
}
case Nemotron3NanoCollectingContent: // Nemotron3NanoCollectingThinking - buffer and look for end markers
if strings.Contains(bufStr, nemotronToolCallOpen) { p.buffer.WriteString(s)
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2) bufStr := p.buffer.String()
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
p.buffer.Reset() // Look for end of thinking: </think> or <tool_call> (model may skip </think>)
p.buffer.WriteString(split[1]) thinkIdx := strings.Index(bufStr, nemotronThinkClose)
p.state = Nemotron3NanoCollectingToolCalls toolIdx := strings.Index(bufStr, nemotronToolCallOpen)
if content != "" {
return []nemotronEvent{nemotronEventContent{content: content}}, true var endIdx int = -1
} var remainder string
return nil, true
} if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen) endIdx = thinkIdx
remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
} else if toolIdx != -1 {
endIdx = toolIdx
remainder = bufStr[toolIdx:] // Include <tool_call> tag
}
if endIdx != -1 {
thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
p.buffer.Reset() p.buffer.Reset()
p.buffer.WriteString(ambig)
if unambig != "" { if remainder == "" {
return []nemotronEvent{nemotronEventContent{content: unambig}}, false p.state = Nemotron3NanoSkipWhitespaceAfterThinking
} else {
p.state = Nemotron3NanoCollectingContent
content, _, calls, err = p.toolParser.Add(remainder, done)
} }
return nil, false return content, thinking, calls, err
case Nemotron3NanoCollectingToolCalls:
if strings.Contains(bufStr, nemotronToolCallClose) {
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
var events []nemotronEvent
if tc, err := p.parseToolCall(split[0]); err == nil {
events = append(events, nemotronEventToolCall{toolCall: tc})
}
if !strings.Contains(remaining, nemotronToolCallOpen) {
p.state = Nemotron3NanoCollectingContent
}
return events, true
}
return nil, false
} }
return nil, false // No end marker - emit unambiguous thinking
thinking = p.emitThinking(bufStr)
return "", thinking, nil, nil
} }
var ( // emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`) func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`) // Check for partial </think> or <tool_call> at end
) thinkOverlap := overlap(bufStr, nemotronThinkClose)
toolOverlap := overlap(bufStr, nemotronToolCallOpen)
maxOverlap := max(thinkOverlap, toolOverlap)
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) { if maxOverlap > 0 {
toolCall := api.ToolCall{} unambiguous := bufStr[:len(bufStr)-maxOverlap]
unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
// Extract function name p.buffer.Reset()
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content) p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
if len(fnMatch) < 2 { return unambiguous
return toolCall, nil
}
toolCall.Function.Name = fnMatch[1]
// Extract parameters
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
for _, match := range paramMatches {
if len(match) >= 3 {
paramName := match[1]
paramValue := strings.TrimSpace(match[2])
// Try to parse as typed value based on tool definition
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
}
} }
return toolCall, nil // No partial tags - emit all but trailing whitespace
} wsLen := trailingWhitespaceLen(bufStr)
if wsLen > 0 {
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any { unambiguous := bufStr[:len(bufStr)-wsLen]
// Find the matching tool to get parameter type p.buffer.Reset()
var paramType api.PropertyType p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
for _, tool := range p.tools { return unambiguous
if tool.Function.Parameters.Properties != nil { }
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
paramType = prop.Type // Nothing to hold back
break p.buffer.Reset()
} return bufStr
}
}
return parseValue(raw, paramType)
} }

View File

@@ -8,6 +8,8 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
func TestNemotron3NanoParser(t *testing.T) { func TestNemotron3NanoParser(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -17,18 +19,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking string expectedThinking string
expectedCalls []api.ToolCall expectedCalls []api.ToolCall
}{ }{
{
name: "simple content - no thinking",
input: "Hello, how can I help you?",
thinkValue: nil,
expectedContent: "Hello, how can I help you?",
},
{
name: "simple content - thinking disabled",
input: "Hello, how can I help you?",
thinkValue: &api.ThinkValue{Value: false},
expectedContent: "Hello, how can I help you?",
},
{ {
name: "thinking then content", name: "thinking then content",
input: "Let me think about this...</think>\nHere is my answer.", input: "Let me think about this...</think>\nHere is my answer.",
@@ -43,69 +33,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude", expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
expectedContent: "The answer is 42.", expectedContent: "The answer is 42.",
}, },
{
name: "simple tool call",
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{
name: "content then tool call",
input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
},
},
},
},
{
name: "tool call with multiple parameters",
input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
{
name: "multiple tool calls",
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
},
},
},
},
{ {
name: "thinking then tool call", name: "thinking then tool call",
input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>", input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
@@ -135,19 +62,6 @@ func TestNemotron3NanoParser(t *testing.T) {
}, },
}, },
}, },
{
name: "tool call with multiline parameter value",
input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
},
},
},
},
{ {
name: "empty thinking block - immediate close", name: "empty thinking block - immediate close",
input: "</think>\nHere is my answer.", input: "</think>\nHere is my answer.",
@@ -161,18 +75,6 @@ func TestNemotron3NanoParser(t *testing.T) {
thinkValue: &api.ThinkValue{Value: false}, thinkValue: &api.ThinkValue{Value: false},
expectedContent: "</think>\nSome content after spurious tag.", expectedContent: "</think>\nSome content after spurious tag.",
}, },
{
name: "tool call with no function name - returns empty tool call",
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
},
{
name: "content with newlines preserved",
input: "Line 1\n\nLine 2\n\n\nLine 3",
thinkValue: nil,
expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
},
{ {
name: "thinking with only whitespace after close tag", name: "thinking with only whitespace after close tag",
input: "My thoughts...</think> \n\t\n Content here.", input: "My thoughts...</think> \n\t\n Content here.",
@@ -180,25 +82,6 @@ func TestNemotron3NanoParser(t *testing.T) {
expectedThinking: "My thoughts...", expectedThinking: "My thoughts...",
expectedContent: "Content here.", expectedContent: "Content here.",
}, },
{
name: "unicode content",
input: "Hello 世界! 🌍 Ñoño",
thinkValue: nil,
expectedContent: "Hello 世界! 🌍 Ñoño",
},
{
name: "tool call with numeric parameter",
input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{"value": "42"}),
},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -233,6 +116,8 @@ func TestNemotron3NanoParser(t *testing.T) {
} }
} }
// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
// Tool call streaming is tested in qwen3coder_test.go.
func TestNemotron3NanoParser_Streaming(t *testing.T) { func TestNemotron3NanoParser_Streaming(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -242,18 +127,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking string expectedThinking string
expectedCalls []api.ToolCall expectedCalls []api.ToolCall
}{ }{
{
name: "streaming content character by character",
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
thinkValue: nil,
expectedContent: "Hello, world!",
},
{
name: "streaming content small tokens",
chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
thinkValue: nil,
expectedContent: "Hello, how can I help you today?",
},
{ {
name: "streaming thinking then content - granular", name: "streaming thinking then content - granular",
chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."}, chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
@@ -268,45 +141,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "Step 1: Analyze\nStep 2: Process", expectedThinking: "Step 1: Analyze\nStep 2: Process",
expectedContent: "The answer.", expectedContent: "The answer.",
}, },
{
name: "streaming tool call - highly granular",
chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{
name: "streaming content then tool call - granular",
chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
thinkValue: nil,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}),
},
},
},
},
{
name: "tool call tag split character by character",
chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
},
{ {
name: "thinking close tag split character by character", name: "thinking close tag split character by character",
chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"}, chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
@@ -321,22 +155,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "Thinking...", expectedThinking: "Thinking...",
expectedContent: "Content here.", expectedContent: "Content here.",
}, },
{
name: "tool call with multiple parameters - streaming",
chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: testArgs(map[string]any{
"from": "SFO",
"to": "NYC",
}),
},
},
},
},
{ {
name: "thinking then content then tool call - streaming", name: "thinking then content then tool call - streaming",
chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"}, chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
@@ -352,45 +170,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
}, },
}, },
}, },
{
name: "multiple tool calls - streaming",
chunks: []string{
"<tool_call>", "\n", "<function=", "get_weather>", "\n",
"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
"</function>", "\n", "</tool_call>", "\n",
"<tool_", "call>\n", "<function", "=get_weather", ">\n",
"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
"</function>\n", "</tool_call>",
},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}),
},
},
},
},
{
name: "tool call with multiline parameter - streaming",
chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
},
},
},
},
{ {
name: "empty thinking block", name: "empty thinking block",
chunks: []string{"</think>", "\n", "Just content."}, chunks: []string{"</think>", "\n", "Just content."},
@@ -398,12 +177,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
expectedThinking: "", expectedThinking: "",
expectedContent: "Just content.", expectedContent: "Just content.",
}, },
{
name: "empty input chunks interspersed",
chunks: []string{"Hello", "", " ", "", "world", "", "!"},
thinkValue: nil,
expectedContent: "Hello world!",
},
{ {
name: "tool call immediately after think close - no content", name: "tool call immediately after think close - no content",
chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"}, chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
@@ -418,25 +191,6 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
}, },
}, },
}, },
{
name: "tool call with empty parameter value",
chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
thinkValue: nil,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: testArgs(map[string]any{"name": ""}),
},
},
},
},
{
name: "partial tool call tag at end - buffered",
chunks: []string{"Here's some content", "<tool"},
thinkValue: nil,
expectedContent: "Here's some content",
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -572,3 +326,65 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
} }
// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
// but the model outputs content + tool call WITHOUT the </think> tag.
// The parser should still parse the tool call (content before is treated as thinking).
func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
chunks := []string{
"Let", " me", " analyze", " this", ".", "\n",
"<tool_call>", "\n",
"<function=get_weather>", "\n",
"<parameter=city>", "Paris", "</parameter>", "\n",
"</function>", "\n",
"</tool_call>",
}
p := &Nemotron3NanoParser{}
p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>
var allContent string
var allThinking string
var allCalls []api.ToolCall
for _, chunk := range chunks {
content, thinking, calls, err := p.Add(chunk, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
}
// Drain
content, thinking, calls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
expectedThinking := "Let me analyze this."
expectedCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
}
if allContent != "" {
t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
}
if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
}

View File

@@ -91,6 +91,37 @@ func TestQwenParserStreaming(t *testing.T) {
}, },
}, },
}, },
{
desc: "tool call tags split character by character",
steps: []step{
{input: "<", wantEvents: []qwenEvent{}},
{input: "t", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "_", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: ">", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "b", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "<", wantEvents: []qwenEvent{}},
{input: "/", wantEvents: []qwenEvent{}},
{input: "t", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "o", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "_", wantEvents: []qwenEvent{}},
{input: "c", wantEvents: []qwenEvent{}},
{input: "a", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: "l", wantEvents: []qwenEvent{}},
{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
},
},
{ {
desc: "trailing whitespace between content and tool call", desc: "trailing whitespace between content and tool call",
steps: []step{ steps: []step{

View File

@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes. // decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) { func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"} types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64 // Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -733,3 +737,60 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly, DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Seed *int64 `json:"seed,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageURLOrData `json:"data"`
}
// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
if req.Options == nil {
req.Options = map[string]any{}
}
req.Options["seed"] = *r.Seed
}
return req
}
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
var data []ImageURLOrData
if resp.Image != "" {
data = []ImageURLOrData{{B64JSON: resp.Image}}
}
return ImageGenerationResponse{
Created: resp.CreatedAt.Unix(),
Data: data,
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
"time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct { type ResponsesTool struct {
Type string `json:"type"` // "function" Type string `json:"type"` // "function"
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description,omitempty"` Description *string `json:"description"` // nullable but required
Strict bool `json:"strict,omitempty"` Strict *bool `json:"strict"` // nullable but required
Parameters map[string]any `json:"parameters,omitempty"` Parameters map[string]any `json:"parameters"` // nullable but required
} }
type ResponsesRequest struct { type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
} }
} }
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{ return api.Tool{
Type: t.Type, Type: t.Type,
Function: api.ToolFunction{ Function: api.ToolFunction{
Name: t.Name, Name: t.Name,
Description: t.Description, Description: description,
Parameters: params, Parameters: params,
}, },
}, nil }, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API // Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct { type ResponsesResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
CreatedAt int64 `json:"created_at"` CreatedAt int64 `json:"created_at"`
Status string `json:"status"` CompletedAt *int64 `json:"completed_at"`
Model string `json:"model"` Status string `json:"status"`
Output []ResponsesOutputItem `json:"output"` IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
Usage *ResponsesUsage `json:"usage,omitempty"` Model string `json:"model"`
// TODO(drifkin): add `temperature` and `top_p` to the response, but this PreviousResponseID *string `json:"previous_response_id"`
// requires additional plumbing to find the effective values since the Instructions *string `json:"instructions"`
// defaults can come from the model or the request Output []ResponsesOutputItem `json:"output"`
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
} }
type ResponsesOutputItem struct { type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
} }
type ResponsesOutputContent struct { type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text" Type string `json:"type"` // "output_text"
Text string `json:"text"` Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
} }
type ResponsesUsage struct { type ResponsesUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
} }
// ToResponse converts an api.ChatResponse to a Responses API response // derefFloat64 returns the value of a float64 pointer, or a default if nil.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse { func derefFloat64(p *float64, def float64) float64 {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
var output []ResponsesOutputItem var output []ResponsesOutputItem
// Add reasoning item if thinking is present // Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
output = append(output, ResponsesOutputItem{ output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i), ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call", Type: "function_call",
Status: "completed",
CallID: tc.ID, CallID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Arguments: tc.Function.Arguments, Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
Role: "assistant", Role: "assistant",
Content: []ResponsesOutputContent{ Content: []ResponsesOutputContent{
{ {
Type: "output_text", Type: "output_text",
Text: chatResponse.Message.Content, Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
}, },
}, },
}) })
} }
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{ return ResponsesResponse{
ID: responseID, ID: responseID,
Object: "response", Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(), CreatedAt: chatResponse.CreatedAt.Unix(),
Status: "completed", CompletedAt: nil, // Set by middleware when writing final response
Model: model, Status: "completed",
Output: output, IncompleteDetails: nil, // Only populated if response incomplete
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
Usage: &ResponsesUsage{ Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount, InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount, OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount, TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
}, },
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
} }
} }
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
responseID string responseID string
itemID string itemID string
model string model string
request ResponsesRequest
// State tracking (mutated across Process calls) // State tracking (mutated across Process calls)
firstWrite bool firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
} }
// NewResponsesStreamConverter creates a new converter with the given configuration. // NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter { func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
return &ResponsesStreamConverter{ return &ResponsesStreamConverter{
responseID: responseID, responseID: responseID,
itemID: itemID, itemID: itemID,
model: model, model: model,
request: request,
firstWrite: true, firstWrite: true,
} }
} }
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events return events
} }
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent { func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{ return c.newEvent("response.created", map[string]any{
"response": map[string]any{ "response": c.buildResponseObject("in_progress", []any{}, nil),
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
}) })
} }
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent { func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{ return c.newEvent("response.in_progress", map[string]any{
"response": map[string]any{ "response": c.buildResponseObject("in_progress", []any{}, nil),
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
}) })
} }
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta // Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{ events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID, "item_id": c.reasoningItemID,
"output_index": c.outputIndex, "output_index": c.outputIndex,
"delta": thinking, "summary_index": 0,
"delta": thinking,
})) }))
// TODO(drifkin): consider adding // TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{ events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{ c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID, "item_id": c.reasoningItemID,
"output_index": c.outputIndex, "output_index": c.outputIndex,
"text": c.accumulatedThinking, "summary_index": 0,
"text": c.accumulatedThinking,
}), }),
c.newEvent("response.output_item.done", map[string]any{ c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex, "output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": c.contentIndex, "content_index": c.contentIndex,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": "", "text": "",
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
} }
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"delta": content, "delta": content,
"logprobs": []any{},
})) }))
return events return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}) })
} }
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"text": c.accumulatedText, "text": c.accumulatedText,
"logprobs": []any{},
})) }))
// response.content_part.done // response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex, "output_index": c.outputIndex,
"content_index": 0, "content_index": 0,
"part": map[string]any{ "part": map[string]any{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}, },
})) }))
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed", "status": "completed",
"role": "assistant", "role": "assistant",
"content": []map[string]any{{ "content": []map[string]any{{
"type": "output_text", "type": "output_text",
"text": c.accumulatedText, "text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}}, }},
}, },
})) }))
} }
// response.completed // response.completed
events = append(events, c.newEvent("response.completed", map[string]any{ usage := map[string]any{
"response": map[string]any{ "input_tokens": r.PromptEvalCount,
"id": c.responseID, "output_tokens": r.EvalCount,
"object": "response", "total_tokens": r.PromptEvalCount + r.EvalCount,
"status": "completed", "input_tokens_details": map[string]any{
"output": c.buildFinalOutput(), "cached_tokens": 0,
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
}, },
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{
"response": response,
})) }))
return events return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
} }
func TestResponsesStreamConverter_TextOnly(t *testing.T) { func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with content // First chunk with content
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
} }
func TestResponsesStreamConverter_ToolCalls(t *testing.T) { func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{ Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
} }
func TestResponsesStreamConverter_Reasoning(t *testing.T) { func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with thinking // First chunk with thinking
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42", Content: "The answer is 42",
}, },
Done: true, Done: true,
}) }, ResponsesRequest{})
// Should have 2 output items: reasoning + message // Should have 2 output items: reasoning + message
if len(response.Output) != 2 { if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) { func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages // Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk // First chunk
converter.Process(api.ChatResponse{ converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) { func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array // Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// Process some content // Process some content
converter.Process(api.ChatResponse{ converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) { func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array // Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"}, Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) { func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers // Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"}, Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) { func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field // Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b") converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{ events := converter.Process(api.ChatResponse{
Message: api.Message{ Message: api.Message{

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
) )
type Prompt struct { type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
} }
type Instance struct { type Instance struct {
Prompt *Prompt Prompt *Prompt
Terminal *Terminal Terminal *Terminal
History *History History *History
Pasting bool Pasting bool
pastedLines []string
} }
func New(prompt Prompt) (*Instance, error) { func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc: case CharEsc:
esc = true esc = true
case CharInterrupt: case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt return "", ErrInterrupt
case CharPrev: case CharPrev:
i.historyPrev(buf, &currentLineBuf) i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward: case CharForward:
buf.MoveRight() buf.MoveRight()
case CharBackspace, CharCtrlH: case CharBackspace, CharCtrlH:
buf.Remove() if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab: case CharTab:
// todo: convert back to real tabs // todo: convert back to real tabs
for range 8 { for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ: case CharCtrlZ:
fd := os.Stdin.Fd() fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios) return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ: case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String() output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" { if output != "" {
i.History.Add(output) i.History.Add(output)
} }
buf.MoveToEnd() buf.MoveToEnd()
fmt.Println() fmt.Println()
i.Prompt.UseAlt = false
return output, nil return output, nil
default: default:

View File

@@ -60,7 +60,7 @@ _build_darwin() {
cmake --install $BUILD_DIR --component MLX cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory # Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else else
BUILD_DIR=build BUILD_DIR=build
cmake --preset MLX \ cmake --preset MLX \
@@ -71,10 +71,12 @@ _build_darwin() {
cmake --install $BUILD_DIR --component MLX cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64 # Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx . GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX . # Copy MLX libraries to same directory as executable for dlopen
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
done done
} }
@@ -82,12 +84,10 @@ _sign_darwin() {
status "Creating universal binary..." status "Creating universal binary..."
mkdir -p dist/darwin mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama chmod +x dist/darwin/ollama
chmod +x dist/darwin/ollama-mlx
if [ -n "$APPLE_IDENTITY" ]; then if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done done
@@ -154,7 +154,6 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F) lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done done
@@ -166,28 +165,27 @@ _build_macapp() {
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/ cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi fi
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign # Sign
if [ -n "$APPLE_IDENTITY" ]; then if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib} codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
fi fi
rm -f dist/Ollama-darwin.zip rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple # Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID" $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app $(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg rm -f dist/Ollama.dmg

View File

@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil return redirectURL, nil
} }
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) { func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL() redirectURL, err := challenge.URL()
if err != nil { if err != nil {
return "", err return "", err
} }
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}
sha256sum := sha256.Sum256(nil) sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))) data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

113
server/auth_test.go Normal file
View File

@@ -0,0 +1,113 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}

View File

@@ -41,6 +41,7 @@ var (
errCapabilityVision = errors.New("vision") errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding") errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking") errCapabilityThinking = errors.New("thinking")
errCapabilityImage = errors.New("image generation")
errInsecureProtocol = errors.New("insecure protocol http") errInsecureProtocol = errors.New("insecure protocol http")
) )
@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {
// Check for image generation model via config capabilities // Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") { if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration} return []model.Capability{model.CapabilityImage}
} }
// Check for completion capability // Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityVision: errCapabilityVision, model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding, model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking, model.CapabilityThinking: errCapabilityThinking,
model.CapabilityImage: errCapabilityImage,
} }
for _, cap := range want { for _, cap := range want {
@@ -775,7 +777,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}) }, base.Host)
} }
if err := transfer.Download(ctx, transfer.DownloadOptions{ if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +852,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm, Realm: challenge.Realm,
Service: challenge.Service, Service: challenge.Service,
Scope: challenge.Scope, Scope: challenge.Scope,
}) }, base.Host)
} }
return transfer.Upload(ctx, transfer.UploadOptions{ return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +918,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry // Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge) token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
Capabilities: []string{"image"}, Capabilities: []string{"image"},
}, },
}, },
expectedCaps: []model.Capability{model.CapabilityImageGeneration}, expectedCaps: []model.Capability{model.CapabilityImage},
}, },
{ {
name: "model with completion capability", name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
checkCaps: []model.Capability{"unknown"}, checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability", expectedErrMsg: "unknown capability",
}, },
{
name: "model missing image generation capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityImage},
expectedErrMsg: "does not support image generation",
},
{
name: "model with image generation capability",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
checkCaps: []model.Capability{model.CapabilityImage},
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -51,7 +51,7 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api" xserver "github.com/ollama/ollama/x/server"
) )
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +164,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) { func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey() pubKey, err := auth.GetPublicKey()
if err != nil { if err != nil {
@@ -214,12 +191,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() { if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with // Ideally this is "invalid model name" but we're keeping with
@@ -249,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 { if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return return
@@ -1125,7 +1102,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
// For image generation models, populate details from imagegen package // For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) { if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil { if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount)) modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
@@ -1133,6 +1110,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
} }
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
if req.System != "" { if req.System != "" {
m.System = req.System m.System = req.System
} }
@@ -1215,7 +1208,27 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil return resp, nil
} }
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) { if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil return resp, nil
} }
@@ -1587,13 +1600,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler) r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility) // Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler) r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil { if rc != nil {
// wrap old with new // wrap old with new
rs := &registry.Local{ rs := &registry.Local{
@@ -2460,3 +2472,91 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
} }
return msgs return msgs
} }
// handleImageGenerate handles image generation requests within GenerateHandler.
// This is called when the model has the Image capability.
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
// Validate image dimensions
const maxDimension int32 = 4096
if req.Width > maxDimension || req.Height > maxDimension {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
return
}
// Schedule the runner for image generation
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// Handle load-only request (empty prompt)
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
// Get seed from options if provided
var seed int64
if s, ok := req.Options["seed"]; ok {
switch v := s.(type) {
case int:
seed = int64(v)
case int64:
seed = v
case float64:
seed = int64(v)
}
}
var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}
if cr.TotalSteps > 0 {
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}
if cr.Image != "" {
res.Image = cr.Image
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.Metrics.TotalDuration = time.Since(checkpointStart)
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
}

View File

@@ -574,7 +574,8 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
Options: &req.opts, Options: &req.opts,
loading: false, loading: false,
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
refCount: 1, totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
} }
s.loadedMu.Lock() s.loadedMu.Lock()

View File

@@ -6,7 +6,6 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"os" "os"
"slices"
"testing" "testing"
"time" "time"
@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false } func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil } func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model // TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request. // loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) { func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond) ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done() defer done()
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner) require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath) require.Equal(t, "/fake/image/model", runner.modelPath)
} }
// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()
// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()
// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)
// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}

View File

@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback() w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge) token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,7 +9,7 @@ const (
CapabilityVision = Capability("vision") CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding") CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking") CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image") CapabilityImage = Capability("image")
) )
func (c Capability) String() string { func (c Capability) String() string {

View File

@@ -1,50 +0,0 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
### Building ollama-mlx
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
#### macOS (Apple Silicon and Intel)
```bash
# Build MLX backend libraries
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
go build -tags mlx -o ollama-mlx .
```
#### Linux (CUDA)
On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
```bash
# Build MLX backend libraries with CUDA support
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
go build -tags mlx -o ollama-mlx .
```
#### Using build scripts
The build scripts automatically create the `ollama-mlx` binary:
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
## Image Generation
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.

View File

@@ -25,14 +25,6 @@ import (
"github.com/ollama/ollama/x/tools" "github.com/ollama/ollama/x/tools"
) )
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants // Tool output capping constants
const ( const (
// localModelTokenLimit is the token limit for local models (smaller context). // localModelTokenLimit is the token limit for local models (smaller context).
@@ -656,7 +648,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
Placeholder: "Send a message (/? for help)", Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`, AltPlaceholder: "Press Enter to send",
}) })
if err != nil { if err != nil {
return err return err
@@ -707,7 +699,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder var sb strings.Builder
var format string var format string
var system string var system string
var multiline MultilineState = MultilineNone
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -721,37 +712,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
} }
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
sb.Reset() sb.Reset()
multiline = MultilineNone
continue continue
case err != nil: case err != nil:
return err return err
} }
switch { switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil return nil
case strings.HasPrefix(line, "/clear"): case strings.HasPrefix(line, "/clear"):
@@ -860,41 +826,18 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]] options[args[2]] = fp[args[2]]
case "system": case "system":
if len(args) < 3 { if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"") fmt.Println("Usage: /set system <message>")
continue continue
} }
multiline = MultilineSystem system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" { if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage messages[len(messages)-1] = newMessage
} else { } else {
messages = append(messages, newMessage) messages = append(messages, newMessage)
} }
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
continue continue
default: default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1081,7 +1024,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line) sb.WriteString(line)
} }
if sb.Len() > 0 && multiline == MultilineNone { if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()} newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage) messages = append(messages, newMessage)

282
x/create/client/create.go Normal file
View File

@@ -0,0 +1,282 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server. This sub-package can import server because server doesn't
// import it.
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
if !isSafetensors && !isImageGen {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
capabilities = []string{"image"}
}
// Set up progress spinner
statusMsg := "importing " + modelType
spinner := progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
progressFn := func(msg string) {
spinner.Stop()
statusMsg = msg
spinner = progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var err error
if isSafetensors {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
}
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
return nil
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
}
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if quantize != "" {
return createQuantizedLayers(r, name, dtype, shape, quantize)
}
return createUnquantizedLayer(r, name)
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []create.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
})
}
// Add Modelfile layers if present
if opts.Modelfile != nil {
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
layers = append(layers, layer)
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
layers = append(layers, layer)
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}

View File

@@ -0,0 +1,146 @@
package client
import (
"testing"
)
func TestModelfileConfig(t *testing.T) {
// Test that ModelfileConfig struct works as expected
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
}
if config.Template != "{{ .Prompt }}" {
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
}
if config.System != "You are a helpful assistant." {
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
}
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
if config.Template != "" {
t.Errorf("Template should be empty, got %q", config.Template)
}
if config.System != "" {
t.Errorf("System should be empty, got %q", config.System)
}
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
// Test config with only some fields set
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
// System and License intentionally empty
}
if config.Template == "" {
t.Error("Template should not be empty")
}
if config.System != "" {
t.Error("System should be empty")
}
if config.License != "" {
t.Error("License should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
// Verify the minimum version constant is set
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
}
}
func TestCreateModel_InvalidDir(t *testing.T) {
// Test that CreateModel returns error for invalid directory
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: "/nonexistent/path",
}, nil)
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
// Test that CreateModel returns error for directory without safetensors
dir := t.TempDir()
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: dir,
}, nil)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
License: "MIT",
},
}
if opts.ModelName != "my-model" {
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
}
if opts.ModelDir != "/path/to/model" {
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
}
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
}
func TestCreateOptions_Defaults(t *testing.T) {
opts := CreateOptions{
ModelName: "test",
ModelDir: "/tmp",
}
// Quantize should default to empty
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
t.Error("Modelfile should be nil by default")
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
supported := QuantizeSupported()
// In non-mlx builds, this should be false
// We can't easily test both cases, so just verify it returns something
_ = supported
}

View File

@@ -11,10 +11,11 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
) )
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8, // quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases. // and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights). // Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir() tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path) // Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
@@ -50,9 +51,18 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData
mlx.Eval(arr) mlx.Eval(arr)
} }
// Quantize with affine mode: group_size=32, bits=8 // Quantize based on quantization type
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does var qweight, scales, qbiases *mlx.Array
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine") switch quantize {
case "fp4":
// affine mode: group_size=32, bits=4
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
// Eval and make contiguous for data access // Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight) qweight = mlx.Contiguous(qweight)

View File

@@ -8,7 +8,7 @@ import (
) )
// quantizeTensor is not available without MLX // quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) { func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)") return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
} }

399
x/create/create.go Normal file
View File

@@ -0,0 +1,399 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
ModelFormat string `json:"model_format"`
Capabilities []string `json:"capabilities"`
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"`
}
// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, err
}
return &manifest, nil
}
// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return nil, err
}
// Read the config blob
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return nil, err
}
var config ModelConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors"
}
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}
// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}
// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return "", err
}
// Find the config.json layer
for _, layer := range manifest.Layers {
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
blobName := strings.Replace(layer.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return "", err
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", err
}
// Prefer model_type, fall back to first architecture
if cfg.ModelType != "" {
return cfg.ModelType, nil
}
if len(cfg.Architectures) > 0 {
return cfg.Architectures[0], nil
}
}
}
return "", fmt.Errorf("architecture not found in model config")
}
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
// Must have config.json
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
return false
}
// Must have at least one .safetensors file
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
return true
}
}
return false
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
// Image gen specific: skip VAE entirely
if component == "vae" {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
}
// Skip layer norms and RMS norms
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
return false
}
// Skip biases
if strings.HasSuffix(name, ".bias") {
return false
}
// Only quantize weights
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
return false
}
return true
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
}
// Process all safetensors files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
quantizeType = quantize
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
// Skip the index file as we don't need it after extraction
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
f.Close()
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use config.json as the config layer
if cfgPath == "config.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("config.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

752
x/create/create_test.go Normal file
View File

@@ -0,0 +1,752 @@
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestIsTensorModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid diffusers model with model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
},
expected: true,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "directory with other files but no model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsTensorModelDir(dir)
if got != tt.expected {
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid safetensors model with config.json and .safetensors file",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
{
name: "config.json only, no safetensors files",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
{
name: "safetensors file only, no config.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: false,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "multiple safetensors files with config.json",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsSafetensorsModelDir(dir)
if got != tt.expected {
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
if got != false {
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
}
}
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
t.Helper()
// Create a minimal safetensors file with a single float32 tensor
header := map[string]interface{}{
"test_tensor": map[string]interface{}{
"dtype": "F32",
"shape": []int{2, 2},
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
},
}
headerJSON, err := json.Marshal(header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Write file
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}
defer f.Close()
// Write header size (8 bytes, little endian)
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
// Write header
if _, err := f.Write(headerJSON); err != nil {
t.Fatalf("failed to write header: %v", err)
}
// Write tensor data (16 bytes of zeros for 4 float32 values)
if _, err := f.Write(make([]byte, 16)); err != nil {
t.Fatalf("failed to write tensor data: %v", err)
}
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Track what was created
var createdLayers []LayerInfo
var manifestWritten bool
var manifestModelName string
var manifestConfigLayer LayerInfo
var manifestLayers []LayerInfo
var statusMessages []string
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
layer := LayerInfo{
Digest: "sha256:test",
Size: int64(len(data)),
MediaType: mediaType,
Name: name,
}
createdLayers = append(createdLayers, layer)
return layer, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
layer := LayerInfo{
Digest: "sha256:tensor_" + name,
Size: int64(len(data)),
MediaType: "application/vnd.ollama.image.tensor",
Name: name,
}
createdLayers = append(createdLayers, layer)
return []LayerInfo{layer}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
manifestConfigLayer = config
manifestLayers = layers
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
// Run CreateSafetensorsModel
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify manifest was written
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-model" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
}
// Verify config layer was set
if manifestConfigLayer.Name != "config.json" {
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
}
// Verify we have at least one tensor and one config layer
hasTensor := false
hasConfig := false
for _, layer := range manifestLayers {
if layer.Name == "test_tensor" {
hasTensor = true
}
if layer.Name == "config.json" {
hasConfig = true
}
}
if !hasTensor {
t.Error("no tensor layer found in manifest")
}
if !hasConfig {
t.Error("no config layer found in manifest")
}
// Verify status messages were sent
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()
// Create only a safetensors file, no config.json
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Mock callbacks (minimal)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing config.json, got nil")
}
}
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
dir := t.TempDir()
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
return []LayerInfo{{}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
dir := t.TempDir()
// Create config.json
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create model.safetensors.index.json (should be skipped)
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
t.Fatalf("failed to write index.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var configNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
configNames = append(configNames, name)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify model.safetensors.index.json was not included
for _, name := range configNames {
if name == "model.safetensors.index.json" {
t.Error("model.safetensors.index.json should have been skipped")
}
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
modelName string
wantParts []string // Parts that should appear in the path
}{
{
name: "simple model name",
modelName: "llama2",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
},
{
name: "model name with tag",
modelName: "llama2:7b",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
},
{
name: "model name with namespace",
modelName: "myuser/mymodel",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
},
{
name: "model name with namespace and tag",
modelName: "myuser/mymodel:v1",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
},
{
name: "fully qualified model name",
modelName: "registry.example.com/namespace/model:tag",
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveManifestPath(tt.modelName)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
}
}
})
}
}
func TestLayerInfo(t *testing.T) {
layer := LayerInfo{
Digest: "sha256:abc123",
Size: 1024,
MediaType: "application/vnd.ollama.image.tensor",
Name: "model.weight",
}
if layer.Digest != "sha256:abc123" {
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
}
if layer.Size != 1024 {
t.Errorf("Size = %d, want %d", layer.Size, 1024)
}
if layer.MediaType != "application/vnd.ollama.image.tensor" {
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
}
if layer.Name != "model.weight" {
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
}
}
func TestModelConfig(t *testing.T) {
config := ModelConfig{
ModelFormat: "safetensors",
Capabilities: []string{"completion", "chat"},
}
if config.ModelFormat != "safetensors" {
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
}
if len(config.Capabilities) != 2 {
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
}
}
func TestManifest(t *testing.T) {
manifest := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.oci.image.manifest.v1+json",
Config: ManifestLayer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:config",
Size: 100,
},
Layers: []ManifestLayer{
{
MediaType: "application/vnd.ollama.image.tensor",
Digest: "sha256:layer1",
Size: 1000,
Name: "weight.bin",
},
},
}
if manifest.SchemaVersion != 2 {
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
}
if manifest.Config.Digest != "sha256:config" {
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
}
if len(manifest.Layers) != 1 {
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
}
if manifest.Layers[0].Name != "weight.bin" {
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
}
}
func TestShouldQuantize(t *testing.T) {
tests := []struct {
name string
tensor string
component string
want bool
}{
// VAE component should never be quantized
{"vae weight", "decoder.weight", "vae", false},
{"vae bias", "decoder.bias", "vae", false},
// Embeddings should not be quantized
{"embedding weight", "embed_tokens.weight", "", false},
{"embedding in name", "token_embedding.weight", "", false},
// Norms should not be quantized
{"layer norm", "layer_norm.weight", "", false},
{"rms norm", "rms_norm.weight", "", false},
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
// Linear weights should be quantized
{"linear weight", "q_proj.weight", "", true},
{"attention weight", "self_attn.weight", "", true},
{"mlp weight", "mlp.gate_proj.weight", "", true},
// Transformer component weights should be quantized
{"transformer weight", "layers.0.weight", "transformer", true},
{"text_encoder weight", "encoder.weight", "text_encoder", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantize(tt.tensor, tt.component)
if got != tt.want {
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
}
})
}
}
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
// Run with quantize enabled
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify quantize was passed to callback (will be false for small test tensor)
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}
// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
t.Helper()
// Create model_index.json
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
t.Fatalf("failed to write model_index.json: %v", err)
}
// Create transformer directory with a safetensors file
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
// Create transformer config
transformerConfig := `{"hidden_size": 3072}`
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
t.Fatalf("failed to write transformer config: %v", err)
}
}
func TestCreateImageGenModel(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var manifestWritten bool
var manifestModelName string
var statusMessages []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-imagegen" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
}
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
dir := t.TempDir()
// Create only transformer without model_index.json
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing model_index.json, got nil")
}
}
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}

View File

@@ -1,4 +1,4 @@
package imagegen package create
import ( import (
"bytes" "bytes"
@@ -12,40 +12,24 @@ import (
"github.com/ollama/ollama/x/imagegen/safetensors" "github.com/ollama/ollama/x/imagegen/safetensors"
) )
// IsTensorModelDir checks if the directory contains a tensor model // CreateImageGenModel imports an image generation model from a directory.
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication. // Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format. // If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles. // Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error { func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp4", "fp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
}
var layers []LayerInfo var layers []LayerInfo
var configLayer LayerInfo var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes var totalParams int64 // Count parameters from original tensor shapes
var torchDtype string // Read from component config for quantization display
// Components to process - extract individual tensors from each // Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"} components := []string{"text_encoder", "transformer", "vae"}
@@ -77,8 +61,8 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
tensorNames := extractor.ListTensors() tensorNames := extractor.ListTensors()
quantizeMsg := "" quantizeMsg := ""
if quantize == "fp8" && component != "vae" { if quantize != "" && component != "vae" {
quantizeMsg = ", quantizing to fp8" quantizeMsg = ", quantizing to " + quantize
} }
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg)) fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
@@ -103,11 +87,14 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
// Use path-style name: "component/tensor_name" // Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName fullName := component + "/" + tensorName
// Determine if this tensor should be quantized // Determine quantization type for this tensor (empty string if not quantizing)
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component) quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
quantizeType = quantize
}
// createTensorLayer returns multiple layers if quantizing (weight + scales) // createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize) newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
if err != nil { if err != nil {
extractor.Close() extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err) return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
@@ -119,6 +106,19 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
} }
} }
// Read torch_dtype from text_encoder config for quantization display
if torchDtype == "" {
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
if data, err := os.ReadFile(textEncoderConfig); err == nil {
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
torchDtype = cfg.TorchDtype
}
}
}
// Import config files // Import config files
configFiles := []string{ configFiles := []string{
"model_index.json", "model_index.json",
@@ -164,11 +164,11 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
// Add parameter count (counted from tensor shapes during import) // Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams cfg["parameter_count"] = totalParams
// Add quantization info // Add quantization info - use quantize type if set, otherwise torch_dtype
if quantize == "fp8" { if quantize != "" {
cfg["quantization"] = "FP8" cfg["quantization"] = strings.ToUpper(quantize)
} else { } else {
cfg["quantization"] = "BF16" cfg["quantization"] = torchDtype
} }
data, err = json.MarshalIndent(cfg, "", " ") data, err = json.MarshalIndent(cfg, "", " ")
@@ -211,3 +211,12 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers))) fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil return nil
} }
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
}

View File

@@ -1,231 +0,0 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
if strings.HasPrefix(content, "IMAGE_BASE64:") {
return content[13:]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
if format == "url" {
// URL format not supported when using base64 transfer
resp.Data[0].B64JSON = imageBase64
} else {
resp.Data[0].B64JSON = imageBase64
}
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// Log error but don't block - channel is already being consumed
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

View File

@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}

View File

@@ -7,7 +7,6 @@ package imagegen
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -39,79 +38,20 @@ func DefaultOptions() ImageGenOptions {
return ImageGenOptions{ return ImageGenOptions{
Width: 1024, Width: 1024,
Height: 1024, Height: 1024,
Steps: 9, Steps: 0, // 0 means model default
Seed: 0, // 0 means random Seed: 0, // 0 means random
} }
} }
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command. // RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models. // Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) { func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width") cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height") cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps") cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)") cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt") cmd.Flags().String("negative", "", "Negative prompt")
// Hide from main flags section - shown in separate section via AppendFlagsDocs
cmd.Flags().MarkHidden("width") cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height") cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps") cmd.Flags().MarkHidden("steps")
@@ -119,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().MarkHidden("negative") cmd.Flags().MarkHidden("negative")
} }
// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
func AppendFlagsDocs(cmd *cobra.Command) {
usage := `
Image Generation Flags (experimental):
--width int Image width
--height int Image height
--steps int Denoising steps
--seed int Random seed
--negative str Negative prompt
`
cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
}
// RunCLI handles the CLI for image generation models. // RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow. // Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative // Supports flags: --width, --height, --steps, --seed, --negative
@@ -158,17 +111,15 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err return err
} }
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: modelName, Model: modelName,
Prompt: prompt, Prompt: prompt,
Options: map[string]any{ Width: int32(opts.Width),
"num_ctx": opts.Width, Height: int32(opts.Height),
"num_gpu": opts.Height, Steps: int32(opts.Steps),
"num_predict": opts.Steps, }
"seed": opts.Seed, if opts.Seed != 0 {
}, req.Options = map[string]any{"seed": opts.Seed}
} }
if keepAlive != nil { if keepAlive != nil {
req.KeepAlive = keepAlive req.KeepAlive = keepAlive
@@ -182,32 +133,25 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
var stepBar *progress.StepBar var stepBar *progress.StepBar
var imageBase64 string var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error { err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response // Handle progress updates using structured fields
if resp.Total > 0 {
// Handle progress updates - parse step info and switch to step bar if stepBar == nil {
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop() spinner.Stop()
stepBar = progress.NewStepBar("Generating", total) stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar) p.Add("", stepBar)
} }
if stepBar != nil { stepBar.Set(int(resp.Completed))
stepBar.Set(step)
}
return nil
} }
// Handle final response with base64 image data // Handle final response with image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") { if resp.Done && resp.Image != "" {
imageBase64 = content[13:] imageBase64 = resp.Image
} }
return nil return nil
}) })
p.Stop() p.StopAndClear()
if err != nil { if err != nil {
return err return err
} }
@@ -245,6 +189,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return err return err
} }
// Preload the model with the specified keepalive
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
preloadReq := &api.GenerateRequest{
Model: modelName,
KeepAlive: keepAlive,
}
if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
return nil
}); err != nil {
p.StopAndClear()
return fmt.Errorf("failed to load model: %w", err)
}
p.StopAndClear()
scanner, err := readline.New(readline.Prompt{ scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ", Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)", Placeholder: "Describe an image to generate (/help for commands)",
@@ -282,7 +243,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
case strings.HasPrefix(line, "/bye"): case strings.HasPrefix(line, "/bye"):
return nil return nil
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"): case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
printInteractiveHelp(opts) printInteractiveHelp()
continue continue
case strings.HasPrefix(line, "/set "): case strings.HasPrefix(line, "/set "):
if err := handleSetCommand(line[5:], &opts); err != nil { if err := handleSetCommand(line[5:], &opts); err != nil {
@@ -301,12 +262,12 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: modelName, Model: modelName,
Prompt: line, Prompt: line,
Options: map[string]any{ Width: int32(opts.Width),
"num_ctx": opts.Width, Height: int32(opts.Height),
"num_gpu": opts.Height, Steps: int32(opts.Steps),
"num_predict": opts.Steps, }
"seed": opts.Seed, if opts.Seed != 0 {
}, req.Options = map[string]any{"seed": opts.Seed}
} }
if keepAlive != nil { if keepAlive != nil {
req.KeepAlive = keepAlive req.KeepAlive = keepAlive
@@ -321,32 +282,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
var imageBase64 string var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error { err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response // Handle progress updates using structured fields
if resp.Total > 0 {
// Handle progress updates - parse step info and switch to step bar if stepBar == nil {
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
spinner.Stop() spinner.Stop()
stepBar = progress.NewStepBar("Generating", total) stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar) p.Add("", stepBar)
} }
if stepBar != nil { stepBar.Set(int(resp.Completed))
stepBar.Set(step)
}
return nil
} }
// Handle final response with base64 image data // Handle final response with image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") { if resp.Done && resp.Image != "" {
imageBase64 = content[13:] imageBase64 = resp.Image
} }
return nil return nil
}) })
p.Stop() p.StopAndClear()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue continue
@@ -397,12 +351,13 @@ func sanitizeFilename(s string) string {
} }
// printInteractiveHelp prints help for interactive mode commands. // printInteractiveHelp prints help for interactive mode commands.
func printInteractiveHelp(opts ImageGenOptions) { // TODO: reconcile /set commands with /set parameter in text gen REPL (cmd/cmd.go)
func printInteractiveHelp() {
fmt.Fprintln(os.Stderr, "Commands:") fmt.Fprintln(os.Stderr, "Commands:")
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")") fmt.Fprintln(os.Stderr, " /set width <n> Set image width")
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")") fmt.Fprintln(os.Stderr, " /set height <n> Set image height")
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")") fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps")
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)") fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed")
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt") fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
fmt.Fprintln(os.Stderr, " /show Show current settings") fmt.Fprintln(os.Stderr, " /show Show current settings")
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")

View File

@@ -1,190 +0,0 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}

View File

@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
return result return result
} }
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default // withStream runs fn with the generation stream as default
func withStream(fn func()) { func withStream(fn func()) {
// Lazy initialization of generationStream
if generationStream == nil {
generationStream = mlx.NewStream()
}
orig := mlx.GetDefaultStream() orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream) mlx.SetDefaultStream(generationStream)
fn() fn()

View File

@@ -7,12 +7,17 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime/pprof" "runtime/pprof"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3" "github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss" "github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama" "github.com/ollama/ollama/x/imagegen/models/llama"
@@ -46,9 +51,9 @@ func main() {
imagePath := flag.String("image", "", "Image path for multimodal models") imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params // Image generation params
width := flag.Int("width", 1024, "Image width") width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
height := flag.Int("height", 1024, "Image height") height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
steps := flag.Int("steps", 9, "Denoising steps") steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed") seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path") out := flag.String("output", "output.png", "Output path")
@@ -61,6 +66,7 @@ func main() {
// Legacy mode flags // Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation") zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation") qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing") qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice var inputImages stringSlice
@@ -78,6 +84,11 @@ func main() {
return return
} }
// Check if MLX initialized successfully
if !mlx.IsMLXAvailable() {
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
}
// CPU profiling // CPU profiling
if *cpuProfile != "" { if *cpuProfile != "" {
f, err := os.Create(*cpuProfile) f, err := os.Create(*cpuProfile)
@@ -117,6 +128,44 @@ func main() {
if err == nil { if err == nil {
err = saveImageArray(img, *out) err = saveImageArray(img, *out)
} }
case *flux2Flag:
m := &flux2.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
// Load input images with EXIF orientation correction
var loadedImages []image.Image
for _, path := range inputImages {
img, loadErr := loadImageWithEXIF(path)
if loadErr != nil {
log.Fatalf("Failed to load image %s: %v", path, loadErr)
}
loadedImages = append(loadedImages, img)
}
// When input images provided and user didn't override dimensions, use 0 to match input
fluxWidth := int32(*width)
fluxHeight := int32(*height)
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
// Both unset, will auto-detect from input
} else if len(loadedImages) > 0 && *width == 0 {
fluxWidth = 0 // Compute from height + aspect ratio
} else if len(loadedImages) > 0 && *height == 0 {
fluxHeight = 0 // Compute from width + aspect ratio
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
Prompt: *prompt,
Width: fluxWidth,
Height: fluxHeight,
Steps: *steps,
GuidanceScale: float32(*cfgScale),
Seed: *seed,
CapturePath: *gpuCapture,
InputImages: loadedImages,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage: case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath) m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil { if loadErr != nil {
@@ -271,6 +320,8 @@ func detectModelKind(modelPath string) (string, error) {
switch index.ClassName { switch index.ClassName {
case "FluxPipeline", "ZImagePipeline": case "FluxPipeline", "ZImagePipeline":
return "zimage", nil return "zimage", nil
case "Flux2KleinPipeline":
return "flux2", nil
} }
} }
return "zimage", nil return "zimage", nil
@@ -291,3 +342,12 @@ func detectModelKind(modelPath string) (string, error) {
return cfg.ModelType, nil return cfg.ModelType, nil
} }
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
func loadImageWithEXIF(path string) (image.Image, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file: %w", err)
}
return imagegen.DecodeImage(data)
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"image" "image"
_ "image/jpeg"
"image/png" "image/png"
"os" "os"
"path/filepath" "path/filepath"
@@ -108,3 +109,160 @@ func clampF(v, min, max float32) float32 {
} }
return v return v
} }
// DecodeImage decodes image bytes with EXIF orientation applied.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
img, _, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, err
}
return applyOrientation(img, orientation), nil
}
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
return 1 // Not JPEG
}
r := bytes.NewReader(data[2:])
for {
var marker [2]byte
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
return 1
}
if marker[1] == 0xE1 { // APP1 (EXIF)
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen < 14 {
r.Seek(int64(segLen), 1)
continue
}
seg := make([]byte, segLen)
if _, err := r.Read(seg); err != nil {
return 1
}
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
return parseTIFFOrientation(seg[6:])
}
continue
}
if marker[1] == 0xD9 || marker[1] == 0xDA {
return 1 // EOI or SOS
}
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
continue // RST markers
}
var lenBytes [2]byte
if _, err := r.Read(lenBytes[:]); err != nil {
return 1
}
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
if segLen > 0 {
r.Seek(int64(segLen), 1)
}
}
}
func parseTIFFOrientation(tiff []byte) int {
if len(tiff) < 8 {
return 1
}
var big bool
switch string(tiff[:2]) {
case "MM":
big = true
case "II":
big = false
default:
return 1
}
u16 := func(b []byte) uint16 {
if big {
return uint16(b[0])<<8 | uint16(b[1])
}
return uint16(b[1])<<8 | uint16(b[0])
}
u32 := func(b []byte) uint32 {
if big {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
}
if u16(tiff[2:4]) != 42 {
return 1
}
ifdOffset := u32(tiff[4:8])
if int(ifdOffset)+2 > len(tiff) {
return 1
}
numEntries := u16(tiff[ifdOffset : ifdOffset+2])
for i := range int(numEntries) {
offset := ifdOffset + 2 + uint32(i)*12
if int(offset)+12 > len(tiff) {
break
}
if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
o := int(u16(tiff[offset+8 : offset+10]))
if o >= 1 && o <= 8 {
return o
}
return 1
}
}
return 1
}
func applyOrientation(img image.Image, orientation int) image.Image {
if orientation <= 1 || orientation > 8 {
return img
}
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
outW, outH := w, h
if orientation >= 5 {
outW, outH = h, w
}
out := image.NewRGBA(image.Rect(0, 0, outW, outH))
for y := range h {
for x := range w {
var dx, dy int
switch orientation {
case 2:
dx, dy = w-1-x, y
case 3:
dx, dy = w-1-x, h-1-y
case 4:
dx, dy = x, h-1-y
case 5:
dx, dy = y, x
case 6:
dx, dy = h-1-y, x
case 7:
dx, dy = h-1-y, w-1-x
case 8:
dx, dy = y, w-1-x
}
out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
}
}
return out
}

View File

@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
} }
return false return false
} }
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes/param) as rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}

View File

@@ -24,9 +24,8 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements. // modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
var modelVRAMEstimates = map[string]uint64{ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE) "ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture) "FluxPipeline": 20 * GB, // ~20GB for Flux
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
} }
// CheckPlatformSupport validates that image generation is supported on the current platform. // CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -72,31 +71,38 @@ func ResolveModelName(modelName string) string {
// EstimateVRAM returns the estimated VRAM needed for an image generation model. // EstimateVRAM returns the estimated VRAM needed for an image generation model.
// Returns a conservative default of 21GB if the model type cannot be determined. // Returns a conservative default of 21GB if the model type cannot be determined.
func EstimateVRAM(modelName string) uint64 { func EstimateVRAM(modelName string) uint64 {
manifest, err := LoadManifest(modelName) className := DetectModelType(modelName)
if err != nil { if estimate, ok := modelVRAMEstimates[className]; ok {
return 21 * GB
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return 21 * GB
}
// Parse just the class name
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return 21 * GB
}
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
return estimate return estimate
} }
return 21 * GB return 21 * GB
} }
// HasTensorLayers checks if the given model has tensor layers. // DetectModelType reads model_index.json and returns the model type.
func HasTensorLayers(modelName string) bool { // Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
return ResolveModelName(modelName) != "" // Returns empty string if detection fails.
func DetectModelType(modelName string) string {
manifest, err := LoadManifest(modelName)
if err != nil {
return ""
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return ""
}
var index struct {
Architecture string `json:"architecture"`
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err != nil {
return ""
}
// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
if index.Architecture != "" {
return index.Architecture
}
return index.ClassName
} }

View File

@@ -72,9 +72,8 @@ func TestCheckMemoryRequirements(t *testing.T) {
func TestModelVRAMEstimates(t *testing.T) { func TestModelVRAMEstimates(t *testing.T) {
// Verify the VRAM estimates map has expected entries // Verify the VRAM estimates map has expected entries
expected := map[string]uint64{ expected := map[string]uint64{
"ZImagePipeline": 21 * GB, "ZImagePipeline": 21 * GB,
"FluxPipeline": 21 * GB, "FluxPipeline": 20 * GB,
"QwenImagePipeline": 80 * GB,
} }
for name, expectedVRAM := range expected { for name, expectedVRAM := range expected {
@@ -94,13 +93,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
} }
} }
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) { func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string // Non-existent model should return empty string
result := ResolveModelName("nonexistent-model") result := ResolveModelName("nonexistent-model")

View File

@@ -3,7 +3,7 @@
package mlx package mlx
/* /*
#include "mlx/c/mlx.h" #include "mlx.h"
#include <stdlib.h> #include <stdlib.h>
// Forward declaration for Go callback // Forward declaration for Go callback

6
x/imagegen/mlx/doc.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build mlx
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
//
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
package mlx

View File

@@ -0,0 +1,439 @@
//go:build ignore
// This tool generates MLX-C dynamic loading wrappers.
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
package main
import (
"bytes"
"flag"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
func findHeaders(directory string) ([]string, error) {
var headers []string
err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
return nil
})
return headers, err
}
func cleanContent(content string) string {
// Remove single-line comments
re := regexp.MustCompile(`//.*?\n`)
content = re.ReplaceAllString(content, "\n")
// Remove multi-line comments
re = regexp.MustCompile(`/\*.*?\*/`)
content = re.ReplaceAllString(content, "")
// Remove preprocessor directives (lines starting with #) - use multiline mode
re = regexp.MustCompile(`(?m)^\s*#.*?$`)
content = re.ReplaceAllString(content, "")
// Remove extern "C" { and } blocks more conservatively
// Only remove the extern "C" { line, not the content inside
re = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
content = re.ReplaceAllString(content, "\n")
// Remove standalone closing braces that are not part of function declarations
re = regexp.MustCompile(`\n\s*\}\s*\n`)
content = re.ReplaceAllString(content, "\n")
// Collapse whitespace and newlines
re = regexp.MustCompile(`\s+`)
content = re.ReplaceAllString(content, " ")
return content
}
func extractParamNames(params string) []string {
if params == "" || strings.TrimSpace(params) == "void" {
return []string{}
}
var names []string
// Split by comma, but respect parentheses (for function pointers)
parts := splitParams(params)
// Remove array brackets
arrayBrackets := regexp.MustCompile(`\[.*?\]`)
// Function pointer pattern
funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)
// Type keywords to skip
typeKeywords := map[string]bool{
"const": true,
"struct": true,
"unsigned": true,
"signed": true,
"long": true,
"short": true,
"int": true,
"char": true,
"float": true,
"double": true,
"void": true,
"size_t": true,
"uint8_t": true,
"uint16_t": true,
"uint32_t": true,
"uint64_t": true,
"int8_t": true,
"int16_t": true,
"int32_t": true,
"int64_t": true,
"intptr_t": true,
"uintptr_t": true,
}
for _, part := range parts {
if part == "" {
continue
}
// Remove array brackets
part = arrayBrackets.ReplaceAllString(part, "")
// For function pointers like "void (*callback)(int)"
if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
names = append(names, matches[1])
continue
}
// Regular parameter: last identifier
tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
if len(tokens) > 0 {
// The last token is usually the parameter name
// Skip type keywords
for i := len(tokens) - 1; i >= 0; i-- {
if !typeKeywords[tokens[i]] {
names = append(names, tokens[i])
break
}
}
}
}
return names
}
func splitParams(params string) []string {
var parts []string
var current bytes.Buffer
depth := 0
for _, char := range params + "," {
switch char {
case '(':
depth++
current.WriteRune(char)
case ')':
depth--
current.WriteRune(char)
case ',':
if depth == 0 {
parts = append(parts, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(char)
}
default:
current.WriteRune(char)
}
}
return parts
}
func parseFunctions(content string) []Function {
var functions []Function
// Match function declarations: return_type function_name(params);
// Matches both mlx_* and _mlx_* functions
pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)
matches := pattern.FindAllStringSubmatch(content, -1)
for _, match := range matches {
returnType := strings.TrimSpace(match[1])
funcName := strings.TrimSpace(match[2])
params := strings.TrimSpace(match[3])
// Skip if this looks like a variable declaration
if params == "" || strings.Contains(params, "{") {
continue
}
// Clean up return type
returnType = strings.Join(strings.Fields(returnType), " ")
// Extract parameter names
paramNames := extractParamNames(params)
// Check if ARM64 guard is needed
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}
return functions
}
func needsARM64Guard(name, retType, params string) bool {
return strings.Contains(name, "float16") ||
strings.Contains(name, "bfloat16") ||
strings.Contains(retType, "float16_t") ||
strings.Contains(retType, "bfloat16_t") ||
strings.Contains(params, "float16_t") ||
strings.Contains(params, "bfloat16_t")
}
func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
// Generate header file
var headerBuf bytes.Buffer
headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
headerBuf.WriteString("//\n")
headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")
headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
headerBuf.WriteString("#include <stdio.h>\n\n")
// Undef all MLX functions to avoid conflicts
headerBuf.WriteString("// Undefine any existing MLX function macros\n")
for _, fn := range functions {
headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
}
headerBuf.WriteString("\n")
// Function pointer extern declarations
headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
headerBuf.WriteString("#endif\n")
}
}
headerBuf.WriteString("\n")
// Initialization function declaration
headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")
// Wrapper function declarations
headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
headerBuf.WriteString("#endif\n")
}
headerBuf.WriteString("\n")
}
headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")
// Write header file
if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
return fmt.Errorf("failed to write header file: %w", err)
}
// Generate implementation file
var implBuf bytes.Buffer
implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
implBuf.WriteString("#include <stdio.h>\n")
implBuf.WriteString("#include <dlfcn.h>\n\n")
// Function pointer definitions
implBuf.WriteString("// Function pointer definitions\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
}
implBuf.WriteString("\n")
// Initialization function
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
implBuf.WriteString(" if (handle == NULL) {\n")
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
implBuf.WriteString(" return -1;\n")
implBuf.WriteString(" }\n\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
implBuf.WriteString(" return -1;\n")
implBuf.WriteString(" }\n")
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
}
implBuf.WriteString(" return 0;\n")
implBuf.WriteString("}\n\n")
// Wrapper function implementations
implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))
// Call through function pointer
if fn.ReturnType != "void" {
implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
} else {
implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
}
// Pass parameters
implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
implBuf.WriteString(");\n")
implBuf.WriteString("}\n")
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
implBuf.WriteString("\n")
}
// Write implementation file
if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
return fmt.Errorf("failed to write implementation file: %w", err)
}
return nil
}
func main() {
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
flag.PrintDefaults()
}
flag.Parse()
args := flag.Args()
if len(args) < 2 {
fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
flag.Usage()
os.Exit(1)
}
headerDir := args[0]
outputHeader := args[1]
// Default implementation file is same name with .c extension
outputImpl := outputHeader
if len(args) > 2 {
outputImpl = args[2]
} else if strings.HasSuffix(outputHeader, ".h") {
outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
}
// Check if header directory exists
if _, err := os.Stat(headerDir); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)
// Find all headers
headers, err := findHeaders(headerDir)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))
// Parse all headers
var allFunctions []Function
seen := make(map[string]bool)
for _, header := range headers {
content, err := os.ReadFile(header)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
continue
}
cleaned := cleanContent(string(content))
functions := parseFunctions(cleaned)
// Deduplicate
for _, fn := range functions {
if !seen[fn.Name] {
seen[fn.Name] = true
allFunctions = append(allFunctions, fn)
}
}
}
fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))
// Generate wrapper files
if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
}

5786
x/imagegen/mlx/mlx.c Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -3,12 +3,13 @@
package mlx package mlx
/* /*
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src #cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate #cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc #cgo linux LDFLAGS: -lstdc++ -ldl
#cgo windows LDFLAGS: -lstdc++
#include "mlx/c/mlx.h" // Use generated wrappers instead of direct MLX headers
#include "mlx.h"
#include <stdlib.h> #include <stdlib.h>
#include <stdint.h> #include <stdint.h>
#include <string.h> #include <string.h>
@@ -42,192 +43,6 @@ static inline mlx_stream cpu_stream() {
// CGO noescape/nocallback hints to reduce CGO overhead // CGO noescape/nocallback hints to reduce CGO overhead
// noescape: pointers won't escape, no heap allocation needed // noescape: pointers won't escape, no heap allocation needed
// nocallback: function won't call back into Go // nocallback: function won't call back into Go
#cgo noescape mlx_add
#cgo nocallback mlx_add
#cgo noescape mlx_subtract
#cgo nocallback mlx_subtract
#cgo noescape mlx_multiply
#cgo nocallback mlx_multiply
#cgo noescape mlx_divide
#cgo nocallback mlx_divide
#cgo noescape mlx_negative
#cgo nocallback mlx_negative
#cgo noescape mlx_abs
#cgo nocallback mlx_abs
#cgo noescape mlx_exp
#cgo nocallback mlx_exp
#cgo noescape mlx_log
#cgo nocallback mlx_log
#cgo noescape mlx_sqrt
#cgo nocallback mlx_sqrt
#cgo noescape mlx_rsqrt
#cgo nocallback mlx_rsqrt
#cgo noescape mlx_square
#cgo nocallback mlx_square
#cgo noescape mlx_power
#cgo nocallback mlx_power
#cgo noescape mlx_erf
#cgo nocallback mlx_erf
#cgo noescape mlx_sigmoid
#cgo nocallback mlx_sigmoid
#cgo noescape mlx_tanh
#cgo nocallback mlx_tanh
#cgo noescape mlx_sin
#cgo nocallback mlx_sin
#cgo noescape mlx_cos
#cgo nocallback mlx_cos
#cgo noescape mlx_maximum
#cgo nocallback mlx_maximum
#cgo noescape mlx_minimum
#cgo nocallback mlx_minimum
#cgo noescape mlx_clip
#cgo nocallback mlx_clip
#cgo noescape mlx_sum
#cgo nocallback mlx_sum
#cgo noescape mlx_sum_axis
#cgo nocallback mlx_sum_axis
#cgo noescape mlx_mean
#cgo nocallback mlx_mean
#cgo noescape mlx_mean_axis
#cgo nocallback mlx_mean_axis
#cgo noescape mlx_var_axis
#cgo nocallback mlx_var_axis
#cgo noescape mlx_argmax
#cgo nocallback mlx_argmax
#cgo noescape mlx_argmax_axis
#cgo nocallback mlx_argmax_axis
#cgo noescape mlx_softmax_axis
#cgo nocallback mlx_softmax_axis
#cgo noescape mlx_cumsum
#cgo nocallback mlx_cumsum
#cgo noescape mlx_matmul
#cgo nocallback mlx_matmul
#cgo noescape mlx_addmm
#cgo nocallback mlx_addmm
#cgo noescape mlx_gather_mm
#cgo nocallback mlx_gather_mm
#cgo noescape mlx_gather_qmm
#cgo nocallback mlx_gather_qmm
#cgo noescape mlx_reshape
#cgo nocallback mlx_reshape
#cgo noescape mlx_transpose_axes
#cgo nocallback mlx_transpose_axes
#cgo noescape mlx_expand_dims
#cgo nocallback mlx_expand_dims
#cgo noescape mlx_squeeze_axis
#cgo nocallback mlx_squeeze_axis
#cgo noescape mlx_flatten
#cgo nocallback mlx_flatten
#cgo noescape mlx_concatenate_axis
#cgo nocallback mlx_concatenate_axis
#cgo noescape mlx_slice
#cgo nocallback mlx_slice
#cgo noescape mlx_slice_update
#cgo nocallback mlx_slice_update
#cgo noescape mlx_as_strided
#cgo nocallback mlx_as_strided
#cgo noescape mlx_view
#cgo nocallback mlx_view
#cgo noescape mlx_contiguous
#cgo nocallback mlx_contiguous
#cgo noescape mlx_pad
#cgo nocallback mlx_pad
#cgo noescape mlx_tile
#cgo nocallback mlx_tile
#cgo noescape mlx_take_axis
#cgo nocallback mlx_take_axis
#cgo noescape mlx_take_along_axis
#cgo nocallback mlx_take_along_axis
#cgo noescape mlx_put_along_axis
#cgo nocallback mlx_put_along_axis
#cgo noescape mlx_where
#cgo nocallback mlx_where
#cgo noescape mlx_argsort_axis
#cgo nocallback mlx_argsort_axis
#cgo noescape mlx_argpartition_axis
#cgo nocallback mlx_argpartition_axis
#cgo noescape mlx_topk_axis
#cgo nocallback mlx_topk_axis
#cgo noescape mlx_less
#cgo nocallback mlx_less
#cgo noescape mlx_greater_equal
#cgo nocallback mlx_greater_equal
#cgo noescape mlx_logical_and
#cgo nocallback mlx_logical_and
#cgo noescape mlx_zeros
#cgo nocallback mlx_zeros
#cgo noescape mlx_zeros_like
#cgo nocallback mlx_zeros_like
#cgo noescape mlx_ones
#cgo nocallback mlx_ones
#cgo noescape mlx_full
#cgo nocallback mlx_full
#cgo noescape mlx_arange
#cgo nocallback mlx_arange
#cgo noescape mlx_linspace
#cgo nocallback mlx_linspace
#cgo noescape mlx_tri
#cgo nocallback mlx_tri
#cgo noescape mlx_astype
#cgo nocallback mlx_astype
#cgo noescape mlx_fast_rms_norm
#cgo nocallback mlx_fast_rms_norm
#cgo noescape mlx_fast_rope
#cgo nocallback mlx_fast_rope
#cgo noescape mlx_fast_scaled_dot_product_attention
#cgo nocallback mlx_fast_scaled_dot_product_attention
#cgo noescape mlx_conv2d
#cgo nocallback mlx_conv2d
#cgo noescape mlx_conv3d
#cgo nocallback mlx_conv3d
#cgo noescape mlx_random_key
#cgo nocallback mlx_random_key
#cgo noescape mlx_random_split
#cgo nocallback mlx_random_split
#cgo noescape mlx_random_categorical_num_samples
#cgo nocallback mlx_random_categorical_num_samples
#cgo noescape mlx_random_normal
#cgo nocallback mlx_random_normal
#cgo noescape mlx_random_uniform
#cgo nocallback mlx_random_uniform
#cgo noescape mlx_array_eval
#cgo nocallback mlx_array_eval
#cgo noescape mlx_eval
#cgo nocallback mlx_eval
#cgo noescape mlx_async_eval
#cgo nocallback mlx_async_eval
#cgo noescape mlx_synchronize
#cgo nocallback mlx_synchronize
#cgo noescape mlx_array_new
#cgo nocallback mlx_array_new
#cgo noescape mlx_array_new_data
#cgo nocallback mlx_array_new_data
#cgo noescape mlx_array_new_float
#cgo nocallback mlx_array_new_float
#cgo noescape mlx_array_free
#cgo nocallback mlx_array_free
#cgo noescape mlx_array_size
#cgo nocallback mlx_array_size
#cgo noescape mlx_array_ndim
#cgo nocallback mlx_array_ndim
#cgo noescape mlx_array_dim
#cgo nocallback mlx_array_dim
#cgo noescape mlx_array_dtype
#cgo nocallback mlx_array_dtype
#cgo noescape mlx_array_item_int32
#cgo nocallback mlx_array_item_int32
#cgo noescape mlx_vector_array_new_data
#cgo nocallback mlx_vector_array_new_data
#cgo noescape mlx_vector_array_free
#cgo nocallback mlx_vector_array_free
#cgo noescape mlx_array_new_int
#cgo nocallback mlx_array_new_int
#cgo noescape mlx_stream_new_device
#cgo nocallback mlx_stream_new_device
#cgo noescape mlx_get_default_stream
#cgo nocallback mlx_get_default_stream
#cgo noescape mlx_set_default_stream
#cgo nocallback mlx_set_default_stream
*/ */
import "C" import "C"
import ( import (
@@ -1322,6 +1137,27 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
return RMSNorm(x, ones, eps) return RMSNorm(x, ones, eps)
} }
// LayerNorm applies layer normalization without learnable params
// (x - mean) / sqrt(var + eps)
func LayerNorm(x *Array, eps float32) *Array {
return LayerNormWithWeightBias(x, nil, nil, eps)
}
// LayerNormWithWeightBias computes layer normalization using mlx.fast
// weight and bias can be nil for elementwise_affine=False
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
res := C.mlx_array_new()
var wc, bc C.mlx_array
if weight != nil {
wc = weight.c
}
if bias != nil {
bc = bias.c
}
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
return newArray(res)
}
// RoPE applies rotary position embeddings using mlx.fast // RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
res := C.mlx_array_new() res := C.mlx_array_new()
@@ -1796,7 +1632,57 @@ func ArgmaxKeepArray(logits *Array) *Array {
var RandomState = []*Array{nil} var RandomState = []*Array{nil}
var randomStateMu sync.Mutex var randomStateMu sync.Mutex
var mlxInitialized bool
var mlxInitError error
// InitMLX initializes the MLX library by dynamically loading libmlxc.
// This must be called before using any MLX functions.
// Returns an error if the library cannot be loaded.
func InitMLX() error {
if mlxInitialized {
return mlxInitError
}
// Try to load the MLX dynamic library
ret := C.mlx_dynamic_init()
if ret != 0 {
errMsg := C.GoString(C.mlx_dynamic_error())
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
return mlxInitError
}
// Initialize all function pointers via dlsym
handle := C.mlx_get_handle()
ret = C.mlx_load_functions(handle)
if ret != 0 {
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
return mlxInitError
}
mlxInitialized = true
mlxInitError = nil
return nil
}
// IsMLXAvailable returns whether MLX was successfully initialized
func IsMLXAvailable() bool {
return mlxInitialized && mlxInitError == nil
}
// GetMLXInitError returns any error that occurred during MLX initialization
func GetMLXInitError() error {
return mlxInitError
}
func init() { func init() {
// Initialize MLX dynamic library first
if err := InitMLX(); err != nil {
// Don't panic in init - let the caller handle the error
// Store the error for later retrieval
mlxInitError = err
return
}
// Lock main goroutine to OS thread for CUDA context stability. // Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads. // CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread() runtime.LockOSThread()

2337
x/imagegen/mlx/mlx.h Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,144 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
#ifdef __APPLE__
// Get path to library in same directory as executable
static char* get_exe_relative_path(const char* libname) {
static char path[1024];
uint32_t size = sizeof(path);
if (_NSGetExecutablePath(path, &size) != 0) {
return NULL;
}
// Get directory of executable
char* dir = dirname(path);
static char fullpath[1024];
snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
return fullpath;
}
#endif
// Try to load library from a specific path
static int try_load_lib(const char* path) {
if (!path) return 0;
mlx_handle = LOAD_LIB(path);
return mlx_handle != NULL;
}
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
const char* lib_path = NULL;
const char* tried_paths[8] = {0};
int num_tried = 0;
#ifdef _WIN32
// Windows: try same directory as executable
lib_path = "libmlxc.dll";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#elif defined(__APPLE__)
// macOS: try executable directory first
lib_path = get_exe_relative_path("libmlxc.dylib");
if (lib_path) {
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
}
// Try build directory (for tests run from repo root)
lib_path = "./build/lib/ollama/libmlxc.dylib";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
// Fallback to system paths
lib_path = "libmlxc.dylib";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#else
// Linux: try build directory first (for tests)
lib_path = "./build/lib/ollama/libmlxc.so";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
// Fallback to system paths
lib_path = "libmlxc.so";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#endif
// Failed to load library - build error message with all tried paths
{
const char* err = LIB_ERROR();
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. Tried: ");
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
}
if (err) {
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
". Last error: %s", err);
}
}
return -1;
success:
mlx_initialized = 1;
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
return 0;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void) {
return mlx_handle;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}

View File

@@ -0,0 +1,29 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H

View File

@@ -4,9 +4,30 @@ package mlx
import ( import (
"fmt" "fmt"
"os"
"path/filepath"
"runtime"
"testing" "testing"
) )
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := InitMLX(); err != nil {
fmt.Printf("Skipping MLX tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive. // TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
func TestBasicCleanup(t *testing.T) { func TestBasicCleanup(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})

View File

@@ -0,0 +1,539 @@
//go:build mlx
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2
import (
"context"
"encoding/json"
"fmt"
"image"
"math"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"golang.org/x/image/draw"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 4 for Klein)
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
InputImages []image.Image // Reference images for image conditioning (already loaded)
}
// Model represents a FLUX.2 Klein model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *qwen3.TextEncoder
Transformer *Flux2Transformer2DModel
VAE *AutoencoderKLFlux2
SchedulerConfig *SchedulerConfig
}
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
var TextEncoderLayerIndices = []int{8, 17, 26}
// Load loads the FLUX.2 Klein model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &qwen3.TextEncoder{}
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
// Load transformer
m.Transformer = &Flux2Transformer2DModel{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
// Load VAE
m.VAE = &AutoencoderKLFlux2{}
if err := m.VAE.Load(manifest); err != nil {
return fmt.Errorf("VAE: %w", err)
}
// Evaluate all weights in a single batch (reduces GPU sync overhead)
fmt.Print(" Evaluating weights... ")
allWeights := mlx.Collect(m.TextEncoder)
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
mlx.Eval(allWeights...)
fmt.Println("✓")
// Load scheduler config
m.SchedulerConfig = DefaultSchedulerConfig()
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
}
}
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
const MaxRefPixels = 728 * 728
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Enable MLX compilation for fused kernels
mlx.EnableCompile()
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 4 // Klein default: 4 steps for distilled model
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
}
// Determine output dimensions
if len(cfg.InputImages) > 0 {
// With input images, compute missing dimension from aspect ratio
// Images are already EXIF-rotated by the caller
bounds := cfg.InputImages[0].Bounds()
imgW, imgH := bounds.Dx(), bounds.Dy()
aspectRatio := float64(imgH) / float64(imgW)
if cfg.Width > 0 && cfg.Height <= 0 {
// Width specified, compute height
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
} else if cfg.Height > 0 && cfg.Width <= 0 {
// Height specified, compute width
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
} else if cfg.Width <= 0 && cfg.Height <= 0 {
// Neither specified, use input dimensions
cfg.Width = int32(imgW)
cfg.Height = int32(imgH)
}
}
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
pixels := int(cfg.Width) * int(cfg.Height)
if pixels > MaxOutputPixels {
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
}
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
tcfg := m.Transformer.TransformerConfig
patchSize := m.VAE.Config.PatchSize
// Latent dimensions: image / 8 (VAE downscale) / patch_size
latentH := cfg.Height / 8
latentW := cfg.Width / 8
patchH := latentH / patchSize[0]
patchW := latentW / patchSize[1]
imgSeqLen := patchH * patchW
// Text encoding with multi-layer extraction (no padding, use true sequence length)
fmt.Print(" Encoding prompt... ")
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
fmt.Println("✓")
// Encode reference images if provided
var refTokens *ImageCondTokens
var refHeights, refWidths []int32
if len(cfg.InputImages) > 0 {
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
var err error
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
if err != nil {
return nil, fmt.Errorf("encode reference images: %w", err)
}
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
limitPixels := MaxRefPixels
if len(cfg.InputImages) > 1 {
limitPixels = MaxRefPixels / 2
}
for _, img := range cfg.InputImages {
_, w, h := PrepareImage(img, limitPixels)
refHeights = append(refHeights, int32(h/16))
refWidths = append(refWidths, int32(w/16))
}
}
// Scheduler
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
latentChannels := m.VAE.Config.LatentChannels
packedChannels := latentChannels * 4 // 32 * 4 = 128
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
// This matches diffusers' _pack_latents
patches := packLatents(latents)
noiseSeqLen := patches.Shape()[1]
// RoPE cache - includes reference images if present
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
// Cleanup setup arrays when done
defer func() {
rope.Cos.Free()
rope.Sin.Free()
promptEmbeds.Free()
if refTokens != nil {
refTokens.Tokens.Free()
}
}()
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
timesteps := make([]*mlx.Array, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
}
// Evaluate setup arrays
fmt.Print(" Evaluating setup... ")
setupStart := time.Now()
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
toEval = append(toEval, timesteps...)
if refTokens != nil {
toEval = append(toEval, refTokens.Tokens)
}
mlx.Eval(toEval...)
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps)
}
loopStart := time.Now()
stepStart := time.Now()
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
timestep := timesteps[i]
// Prepare input - concatenate noise patches with reference tokens if present
imgInput := patches
if refTokens != nil {
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
}
// Transformer forward pass
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
// If we concatenated reference tokens, slice to only get noise portion
if refTokens != nil {
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
}
// Scheduler step (keep reference to old patches for the computation graph)
newPatches := scheduler.Step(output, patches, i)
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
mlx.Eval(newPatches)
patches = newPatches
elapsed := time.Since(stepStart).Seconds()
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
if i == 0 {
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
} else {
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
}
stepStart = time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
}
loopTime := time.Since(loopStart).Seconds()
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
// Free timesteps now that denoising is done
for _, ts := range timesteps {
ts.Free()
}
// VAE decode with tiling for larger images
fmt.Print(" Decoding VAE... ")
vaeStart := time.Now()
// Enable tiling for images > 512x512 (latent > 64x64)
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
if patchH*2 > 64 || patchW*2 > 64 {
m.VAE.Tiling = DefaultTilingConfig()
}
decoded := m.VAE.Decode(patches, patchH, patchW)
mlx.Eval(decoded)
// Free patches now that decode is done
patches.Free()
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
func packLatents(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
x = mlx.Reshape(x, B, C, H*W)
return mlx.Transpose(x, 0, 2, 1)
}
// LoadPersistent loads the model and keeps it in memory for repeated use.
func LoadPersistent(modelName string) (*Model, error) {
m := &Model{}
if err := m.Load(modelName); err != nil {
return nil, err
}
return m, nil
}
// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
const ImageRefScale = 10
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
// Returns the processed image and its dimensions.
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Cap pixels if needed (like diffusers cap_pixels)
if limitPixels > 0 && w*h > limitPixels {
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
w = int(float64(w) * scale)
h = int(float64(h) * scale)
}
// Round down to multiple of 16
w = (w / 16) * 16
h = (h / 16) * 16
if w < 16 {
w = 16
}
if h < 16 {
h = 16
}
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
resized := image.NewRGBA(image.Rect(0, 0, w, h))
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
return resized, w, h
}
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
func ImageToTensor(img image.Image) *mlx.Array {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
data := make([]float32, 3*h*w)
for y := 0; y < h; y++ {
for x := 0; x < w; x++ {
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
// RGBA returns 16-bit values, convert to [-1, 1]
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
}
}
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
return arr
}
// ImageCondTokens holds encoded reference image tokens.
type ImageCondTokens struct {
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
}
// EncodeImageRefs encodes reference images using the VAE.
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
if len(images) == 0 {
return nil, nil
}
// Limit reference images to reduce attention memory
limitPixels := MaxRefPixels
if len(images) > 1 {
limitPixels = MaxRefPixels / 2
}
var allTokens []*mlx.Array
for _, img := range images {
// Prepare image (resize, crop to multiple of 16)
prepared, prepW, prepH := PrepareImage(img, limitPixels)
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
// Convert to tensor [-1, 1]
tensor := ImageToTensor(prepared)
// Encode with VAE - returns [1, L, 128]
encoded := m.VAE.EncodeImage(tensor)
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
// Defer eval - will be done with other setup arrays
allTokens = append(allTokens, squeezed)
fmt.Println("✓")
}
// For single image, just add batch dimension directly
// For multiple images, concatenate first
var tokens *mlx.Array
if len(allTokens) == 1 {
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
} else {
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
}
return &ImageCondTokens{Tokens: tokens}, nil
}

View File

@@ -0,0 +1,224 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// RoPEConfig holds 4D RoPE configuration for Flux2
type RoPEConfig struct {
Theta int32 // 2000 for Klein
AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
}
// RoPECache holds precomputed RoPE cos/sin values
type RoPECache struct {
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
TextLen int32 // Length of text sequence
ImageLen int32 // Length of image sequence
}
// PrepareTextIDs creates position IDs for text tokens.
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
ids := make([]float32, seqLen*4)
for i := int32(0); i < seqLen; i++ {
idx := i * 4
ids[idx+0] = 0 // T = 0
ids[idx+1] = 0 // H = 0
ids[idx+2] = 0 // W = 0
ids[idx+3] = float32(i) // L = sequence position
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareLatentIDs creates position IDs for image latent tokens.
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
// The latents are in row-major order (H then W).
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
seqLen := height * width
ids := make([]float32, seqLen*4)
idx := 0
for h := int32(0); h < height; h++ {
for w := int32(0); w < width; w++ {
ids[idx*4+0] = 0 // T = 0
ids[idx*4+1] = float32(h) // H = row
ids[idx*4+2] = float32(w) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
// Returns: [total_tokens, 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
// Calculate total tokens
totalTokens := int32(0)
for i := range imageHeights {
totalTokens += imageHeights[i] * imageWidths[i]
}
ids := make([]float32, totalTokens*4)
idx := int32(0)
for imgIdx, h := range imageHeights {
w := imageWidths[imgIdx]
tValue := float32(scale * int32(imgIdx+1))
for hi := int32(0); hi < h; hi++ {
for wi := int32(0); wi < w; wi++ {
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
ids[idx*4+1] = float32(hi) // H = row
ids[idx*4+2] = float32(wi) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
}
return mlx.NewArray(ids, []int32{totalTokens, 4})
}
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
// ids: [L, 4] with (T, H, W, L) coordinates
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
// theta: base frequency (2000 for Klein)
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
shape := ids.Shape()
seqLen := shape[0]
// Compute total head dim (sum of all axes dims)
headDim := int32(0)
for _, d := range axesDims {
headDim += d
}
// Extract each coordinate dimension
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
// Compute frequencies for each axis
logTheta := float32(math.Log(float64(theta)))
cosArrs := make([]*mlx.Array, 4)
sinArrs := make([]*mlx.Array, 4)
positions := []*mlx.Array{posT, posH, posW, posL}
for i, axisDim := range axesDims {
half := axisDim / 2
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
freqs := make([]float32, half)
for j := int32(0); j < half; j++ {
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
}
freqArr := mlx.NewArray(freqs, []int32{1, half})
// Compute pos * freq -> [L, half]
posExpanded := positions[i] // [L, 1]
args := mlx.Mul(posExpanded, freqArr) // [L, half]
// Compute cos and sin for this axis
cosAxis := mlx.Cos(args) // [L, half]
sinAxis := mlx.Sin(args) // [L, half]
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
sinAxis = mlx.ExpandDims(sinAxis, 2)
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
cosArrs[i] = cosAxis
sinArrs[i] = sinAxis
}
// Concatenate all axes: [L, headDim]
cos := mlx.Concatenate(cosArrs, 1)
sin := mlx.Concatenate(sinArrs, 1)
// Reshape to [1, L, 1, headDim] for broadcasting with attention
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
return cos, sin
}
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
// x: [B, L, nheads, head_dim]
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
// Returns: x with RoPE applied
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
// Extract real (index 0) and imag (index 1) parts
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
negXImag := mlx.Neg(xImag)
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
}
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
// textLen: number of text tokens
// noiseH, noiseW: dimensions of the noise latent in patch tokens
// axesDims: [32, 32, 32, 32]
// theta: 2000
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
// scale: time coordinate offset between reference images (e.g., 10)
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
textIDs := PrepareTextIDs(textLen)
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
var allIDs *mlx.Array
imageLen := noiseH * noiseW
if len(refHeights) > 0 {
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
for i := range refHeights {
imageLen += refHeights[i] * refWidths[i]
}
} else {
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
}
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
cos = mlx.ToBFloat16(cos)
sin = mlx.ToBFloat16(sin)
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
}

View File

@@ -0,0 +1,149 @@
//go:build mlx
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds Flow-Match scheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0 for Klein
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "exponential" or "linear"
}
// DefaultSchedulerConfig returns default config for Klein
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0, // Klein uses 3.0
UseDynamicShifting: true,
TimeShiftType: "exponential",
}
}
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
// Then applies time shift, appends 0.0 at end
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
// linspace(1, 1/num_steps, num_steps)
var sigma float32
if numSteps == 1 {
sigma = 1.0
} else {
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
}
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
sigma = s.timeShift(mu, sigma)
} else {
// If not dynamic shifting, apply fixed shift scaling like diffusers
shift := s.Config.Shift
sigma = shift * sigma / (1 + (shift-1)*sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal zero
s.Sigmas[numSteps] = 0.0
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
s.Timesteps = make([]float32, numSteps+1)
for i, v := range s.Sigmas {
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
}
}
// timeShift applies the dynamic time shift
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if s.Config.TimeShiftType == "linear" {
return mu / (mu + (1.0/t-1.0))
}
// Default: exponential
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 for precision (matches diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(outputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to bfloat16
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// CalculateShift computes the mu shift value for dynamic scheduling
// Matches diffusers compute_empirical_mu function
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
a1, b1 := float32(8.73809524e-05), float32(1.89833333)
a2, b2 := float32(0.00016927), float32(0.45666666)
seqLen := float32(imgSeqLen)
if imgSeqLen > 4300 {
return a2*seqLen + b2
}
m200 := a2*seqLen + b2
m10 := a1*seqLen + b1
a := (m200 - m10) / 190.0
b := m200 - 200.0*a
return a*float32(numSteps) + b
}

View File

@@ -0,0 +1,562 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
func (c *TransformerConfig) InnerDim() int32 {
return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
}
func (c *TransformerConfig) MLPHiddenDim() int32 {
return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
}
// TimestepEmbedder creates timestep embeddings
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"linear_1"`
Linear2 nn.LinearLayer `weight:"linear_2"`
EmbedDim int32 // 256
}
// Forward creates sinusoidal embeddings and projects them
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
half := t.EmbedDim / 2
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// timesteps: [B] -> [B, 1]
tExpanded := mlx.ExpandDims(timesteps, 1)
// args: [B, half]
args := mlx.Mul(tExpanded, freqsArr)
// [cos(args), sin(args)] -> [B, embed_dim]
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
// MLP: linear_1 -> silu -> linear_2
h := t.Linear1.Forward(sinEmbed)
h = mlx.SiLU(h)
return t.Linear2.Forward(h)
}
// TimeGuidanceEmbed wraps the timestep embedder
// Weight names: time_guidance_embed.timestep_embedder.*
type TimeGuidanceEmbed struct {
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
}
// Forward computes timestep embeddings
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
return t.TimestepEmbedder.Forward(timesteps)
}
// Modulation computes adaptive modulation parameters
// Weight names: double_stream_modulation_img.linear.weight, etc.
type Modulation struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes modulation parameters
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
h := mlx.SiLU(temb)
return m.Linear.Forward(h)
}
// TransformerBlockAttn implements dual-stream attention
// Weight names: transformer_blocks.N.attn.*
type TransformerBlockAttn struct {
// Image stream (separate Q, K, V projections)
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
// Note: to_out has .0 suffix in weights, handled specially
ToOut0 nn.LinearLayer `weight:"to_out.0"`
// Text stream (add_ projections)
AddQProj nn.LinearLayer `weight:"add_q_proj"`
AddKProj nn.LinearLayer `weight:"add_k_proj"`
AddVProj nn.LinearLayer `weight:"add_v_proj"`
ToAddOut nn.LinearLayer `weight:"to_add_out"`
// QK norms for image stream
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
// QK norms for text stream (added)
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
}
// FeedForward implements SwiGLU MLP
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
type FeedForward struct {
LinearIn nn.LinearLayer `weight:"linear_in"`
LinearOut nn.LinearLayer `weight:"linear_out"`
}
// Forward applies SwiGLU MLP
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// LinearIn outputs 2x hidden dim for SwiGLU
h := ff.LinearIn.Forward(x)
shape := h.Shape()
half := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
h = mlx.Mul(mlx.SiLU(gate), up)
return ff.LinearOut.Forward(h)
}
// TransformerBlock implements a dual-stream transformer block
// Weight names: transformer_blocks.N.*
type TransformerBlock struct {
Attn *TransformerBlockAttn `weight:"attn"`
FF *FeedForward `weight:"ff"`
FFContext *FeedForward `weight:"ff_context"`
// Config (set after loading)
NHeads int32
HeadDim int32
Scale float32
}
// Forward applies the dual-stream block
// imgHidden: [B, imgLen, dim]
// txtHidden: [B, txtLen, dim]
// imgMod, txtMod: modulation params [B, 6*dim] each
// cos, sin: RoPE values
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := imgHidden.Shape()
B := imgShape[0]
imgLen := imgShape[1]
dim := imgShape[2]
txtLen := txtHidden.Shape()[1]
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
// === Attention branch ===
// Modulate inputs
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
// Compute Q, K, V for image stream (separate projections)
imgQ := block.Attn.ToQ.Forward(imgNorm)
imgK := block.Attn.ToK.Forward(imgNorm)
imgV := block.Attn.ToV.Forward(imgNorm)
// Compute Q, K, V for text stream (add_ projections)
txtQ := block.Attn.AddQProj.Forward(txtNorm)
txtK := block.Attn.AddKProj.Forward(txtNorm)
txtV := block.Attn.AddVProj.Forward(txtNorm)
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
// Apply QK norm (RMSNorm with learned scale)
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
imgK = applyQKNorm(imgK, block.Attn.NormK)
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
// Concatenate for joint attention: text first, then image
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA: [B, nheads, L, headDim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back: [B, L, nheads, headDim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Split back into txt and img
totalLen := txtLen + imgLen
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
// Reshape and project
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
txtOut = block.Attn.ToAddOut.Forward(txtOut)
imgOut = block.Attn.ToOut0.Forward(imgOut)
// Apply gates and residual
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
// === MLP branch ===
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
imgFFOut := block.FF.Forward(imgNorm)
txtFFOut := block.FFContext.Forward(txtNorm)
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
return imgHidden, txtHidden
}
// SingleTransformerBlockAttn implements attention for single-stream blocks
// Weight names: single_transformer_blocks.N.attn.*
type SingleTransformerBlockAttn struct {
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
}
// SingleTransformerBlock implements a single-stream transformer block
// Weight names: single_transformer_blocks.N.*
type SingleTransformerBlock struct {
Attn *SingleTransformerBlockAttn `weight:"attn"`
// Config
NHeads int32
HeadDim int32
InnerDim int32
MLPHidDim int32
Scale float32
}
// Forward applies the single-stream block
// x: [B, L, dim] concatenated text+image
// mod: modulation [B, 3*dim]
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
dim := shape[2]
// Parse modulation: (shift, scale, gate)
shift, scale, gate := parseModulation3(mod, dim, 0)
// Modulate input
h := modulateLayerNorm(x, shift, scale)
// Fused projection: QKV + MLP gate/up
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
// Split: first 3*dim is QKV, rest is MLP
qkvDim := 3 * block.InnerDim
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
// Split QKV
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
// Reshape for attention
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
// QK norm
q = applyQKNorm(q, block.Attn.NormQ)
k = applyQKNorm(k, block.Attn.NormK)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back and reshape
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
// MLP: SwiGLU
mlpShape := mlpIn.Shape()
half := mlpShape[2] / 2
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
// Concatenate attention and MLP for fused output
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
// Output projection
out := block.Attn.ToOut.Forward(combined)
// Apply gate and residual
return mlx.Add(x, mlx.Mul(gate, out))
}
// NormOut implements the output normalization with modulation
// Weight names: norm_out.linear.weight
type NormOut struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes final modulated output
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
dim := shape[2]
// Modulation: temb -> silu -> linear -> [shift, scale]
mod := mlx.SiLU(temb)
mod = n.Linear.Forward(mod)
// Split into scale and shift (diffusers order: scale first, shift second)
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
// Modulate with RMSNorm
return modulateLayerNorm(x, shift, scale)
}
// Flux2Transformer2DModel is the main Flux2 transformer
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
type Flux2Transformer2DModel struct {
// Timestep embedding
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
// Shared modulation
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
// Embedders
XEmbedder nn.LinearLayer `weight:"x_embedder"`
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
// Transformer blocks
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
// Output
NormOut *NormOut `weight:"norm_out"`
ProjOut nn.LinearLayer `weight:"proj_out"`
*TransformerConfig
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
// Initialize slices
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
// Initialize TimeGuidanceEmbed with embed dim
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
}
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Flux2Transformer2DModel) initComputedFields() {
cfg := m.TransformerConfig
innerDim := cfg.InnerDim()
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
// Initialize transformer blocks
for _, block := range m.TransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.Scale = scale
}
// Initialize single transformer blocks
for _, block := range m.SingleTransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.InnerDim = innerDim
block.MLPHidDim = cfg.MLPHiddenDim()
block.Scale = scale
}
}
// Forward runs the Flux2 transformer
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
patchShape := patches.Shape()
B := patchShape[0]
imgLen := patchShape[1]
txtLen := txtEmbeds.Shape()[1]
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
// Compute timestep embedding
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
// Embed patches and text
imgHidden := m.XEmbedder.Forward(patches)
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
// Compute shared modulation
imgMod := m.DoubleStreamModulationImg.Forward(temb)
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
singleMod := m.SingleStreamModulation.Forward(temb)
// Double (dual-stream) blocks
for _, block := range m.TransformerBlocks {
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
}
// Concatenate for single-stream: text first, then image
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
// Single-stream blocks
for _, block := range m.SingleTransformerBlocks {
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
}
// Extract image portion
totalLen := txtLen + imgLen
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
// Final norm and projection
imgOut = m.NormOut.Forward(imgOut, temb)
return m.ProjOut.Forward(imgOut)
}
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
// See applyQKNorm function below
// compiledSwiGLU fuses: silu(gate) * up
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
var compiledSwiGLU *mlx.CompiledFunc
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
gate, up := inputs[0], inputs[1]
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
}, true)
}
return compiledSwiGLU
}
// Helper functions
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
B := mod.Shape()[0]
start := offset * dim
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
// Expand for broadcasting [B, dim] -> [B, 1, dim]
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
gate = mlx.ExpandDims(gate, 1)
return shift, scale, gate
}
// modulateLayerNorm applies LayerNorm then shift/scale modulation
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
// Fast LayerNorm without learnable params
x = mlx.LayerNorm(x, 1e-6)
// Modulate: x * (1 + scale) + shift
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
return mlx.Add(x, shift)
}
// splitQKV splits a fused QKV tensor into Q, K, V
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
return q, k, v
}
// applyQKNorm applies RMSNorm with learned scale (no bias)
// Uses the optimized mlx_fast_rms_norm
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
return mlx.RMSNorm(x, scale, 1e-6)
}

View File

@@ -0,0 +1,804 @@
//go:build mlx
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
RunningMean *mlx.Array // [C]
RunningVar *mlx.Array // [C]
Weight *mlx.Array // [C] gamma
Bias *mlx.Array // [C] beta
Eps float32
Momentum float32
}
// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
// Scale and shift (only if affine=True)
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Inverse: first undo affine, then undo normalization
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
x = mlx.Sub(x, bias)
}
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
x = mlx.Add(x, mean)
return x
}
// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias,optional"`
Stride int32
Padding int32
}
// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
if field == "Weight" {
return mlx.Transpose(arr, 0, 2, 3, 1)
}
return arr
}
// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer `weight:"norm1"`
Conv1 *Conv2D `weight:"conv1"`
Norm2 *GroupNormLayer `weight:"norm2"`
Conv2 *Conv2D `weight:"conv2"`
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"`
}
// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
if rb.ConvShortcut != nil {
x = rb.ConvShortcut.Forward(x)
}
return mlx.Add(h, x)
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer `weight:"group_norm"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
}
// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
h := ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
q := ab.ToQ.Forward(h)
k := ab.ToK.Forward(h)
v := ab.ToV.Forward(h)
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
out = ab.ToOut.Forward(out)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
x = resnet.Forward(x)
}
if ub.Upsample != nil {
x = upsample2x(x)
x = ub.Upsample.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
x = mlx.Take(x, hIdx, 1)
x = mlx.Take(x, wIdx, 2)
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
x = mb.Resnet1.Forward(x)
x = mb.Attention.Forward(x)
x = mb.Resnet2.Forward(x)
return x
}
// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
return vae.DefaultTilingConfig()
}
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
type AutoencoderKLFlux2 struct {
Config *VAEConfig
// Encoder components (for image editing)
EncoderConvIn *Conv2D
EncoderMid *VAEMidBlock
EncoderDown []*DownEncoderBlock2D
EncoderNormOut *GroupNormLayer
EncoderConvOut *Conv2D
// Decoder components
DecoderConvIn *Conv2D
DecoderMid *VAEMidBlock
DecoderUp []*UpDecoderBlock2D
DecoderNormOut *GroupNormLayer
DecoderConvOut *Conv2D
// Quant conv layers
QuantConv *Conv2D
PostQuantConv *Conv2D
// BatchNorm for latent normalization
LatentBN *BatchNorm2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// DownEncoderBlock2D implements a downsampling encoder block
type DownEncoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Downsample *Conv2D
}
// Forward applies the down encoder block
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range db.ResnetBlocks {
x = resnet.Forward(x)
}
if db.Downsample != nil {
// Pad then conv with stride 2
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
x = db.Downsample.Forward(x)
}
return x
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder components (for image conditioning)
if err := m.loadEncoderWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
// Load decoder conv_in
m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
return fmt.Errorf("decoder.conv_in: %w", err)
}
// Load mid block
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("decoder.mid_block: %w", err)
}
// Load up blocks
numBlocks := len(cfg.BlockOutChannels)
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load decoder conv_norm_out and conv_out
m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
return fmt.Errorf("decoder.conv_norm_out: %w", err)
}
m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
return fmt.Errorf("decoder.conv_out: %w", err)
}
// Load post_quant_conv
if cfg.UsePostQuantConv {
m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
return fmt.Errorf("post_quant_conv: %w", err)
}
}
// Load latent BatchNorm (affine=False, so no weight/bias)
bnMean, err := weights.GetTensor("bn.running_mean")
if err != nil {
return fmt.Errorf("bn.running_mean: %w", err)
}
bnVar, err := weights.GetTensor("bn.running_var")
if err != nil {
return fmt.Errorf("bn.running_var: %w", err)
}
m.LatentBN = &BatchNorm2D{
RunningMean: bnMean,
RunningVar: bnVar,
Weight: nil, // affine=False
Bias: nil, // affine=False
Eps: cfg.BatchNormEps,
Momentum: cfg.BatchNormMomentum,
}
fmt.Println("✓")
return nil
}
// loadVAEMidBlock loads the mid block.
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// loadResnetBlock2D loads a ResNet block.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
block := &ResnetBlock2D{
Norm1: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv1: &Conv2D{Stride: 1, Padding: 1},
Norm2: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv2: &Conv2D{Stride: 1, Padding: 1},
ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, err
}
// If ConvShortcut wasn't loaded (no weights found), nil it out
if block.ConvShortcut.Weight == nil {
block.ConvShortcut = nil
}
return block, nil
}
// loadVAEAttentionBlock loads an attention block using LoadModule.
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
ab := &VAEAttentionBlock{
GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
}
if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
return nil, err
}
return ab, nil
}
// loadUpDecoderBlock2D loads an up decoder block.
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upsample = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
return nil, err
}
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
// This is the inverse of the VAE's patchify for feeding to transformer
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
pH := H / patchH
pW := W / patchW
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
H := pH * patchH
W := pW * patchW
return mlx.Reshape(x, B, C, H, W)
}
// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
x = mlx.Sub(x, bias)
}
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
x = mlx.Add(x, mean)
return x
}
// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
// Denormalize patchified latents
z := v.denormalizePatchified(latents)
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
// Convert NCHW -> NHWC for processing
z = mlx.Transpose(z, 0, 2, 3, 1)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode (no tiling)
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
h = mlx.Transpose(h, 0, 3, 1, 2)
return h
}
// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
// Post-quant conv
if vae.PostQuantConv != nil {
z = vae.PostQuantConv.Forward(z)
}
// Decoder
h := vae.DecoderConvIn.Forward(z)
h = vae.DecoderMid.Forward(h)
for _, upBlock := range vae.DecoderUp {
h = upBlock.Forward(h)
}
h = vae.DecoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.DecoderConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
return h
}
// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder conv_in
m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
return fmt.Errorf("encoder.conv_in: %w", err)
}
// Load encoder down blocks
numBlocks := len(cfg.BlockOutChannels)
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
hasDownsample := i < numBlocks-1
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load encoder mid block
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("encoder.mid_block: %w", err)
}
// Load encoder conv_norm_out and conv_out
m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
return fmt.Errorf("encoder.conv_norm_out: %w", err)
}
m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
return fmt.Errorf("encoder.conv_out: %w", err)
}
// Load quant_conv (for encoding)
if cfg.UseQuantConv {
m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
return fmt.Errorf("quant_conv: %w", err)
}
}
return nil
}
// loadDownEncoderBlock2D loads a down encoder block.
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var downsample *Conv2D
if hasDownsample {
downsample = &Conv2D{Stride: 2, Padding: 0}
if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
return nil, err
}
}
return &DownEncoderBlock2D{
ResnetBlocks: resnets,
Downsample: downsample,
}, nil
}
// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
// Convert NCHW -> NHWC
x := mlx.Transpose(image, 0, 2, 3, 1)
// Encoder
h := vae.EncoderConvIn.Forward(x)
for _, downBlock := range vae.EncoderDown {
h = downBlock.Forward(h)
}
h = vae.EncoderMid.Forward(h)
h = vae.EncoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.EncoderConvOut.Forward(h)
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
if vae.QuantConv != nil {
h = vae.QuantConv.Forward(h)
}
// Take only the mean (first latent_channels) - deterministic encoding
// h is [B, H, W, 64] -> take first 32 channels for mean
shape := h.Shape()
latentChannels := vae.Config.LatentChannels // 32
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
// Convert NHWC -> NCHW for patchifying
h = mlx.Transpose(h, 0, 3, 1, 2)
// Patchify: [B, C, H, W] -> [B, L, C*4]
h = vae.Patchify(h)
// Apply BatchNorm on patchified latents [B, L, 128]
// The BatchNorm has 128 channels matching the patchified dimension
h = vae.normalizePatchified(h)
return h
}
// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
// Scale and shift (only if affine=True)
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}

View File

@@ -0,0 +1,390 @@
//go:build mlx
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds Qwen3 text encoder configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Attention implements Qwen3 attention with QK norms
type Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking and optional padding mask
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// MLP implements Qwen3 SwiGLU MLP
type MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Block represents a single Qwen3 transformer block
type Block struct {
Attention *Attention `weight:"self_attn"`
MLP *MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h, mask, maskMode)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// TextEncoder is the full Qwen3 encoder
type TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
// This is used by Flux2 which needs embeddings from specific intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
outputs := make([]*mlx.Array, len(layerIndices))
layerSet := make(map[int]int)
for i, idx := range layerIndices {
layerSet[idx] = i
}
for i, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
if outIdx, ok := layerSet[i]; ok {
outputs[outIdx] = h
}
}
return outputs
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format.
// If think is true, adds the <think></think> block after the assistant tag
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
func ApplyChatTemplate(prompt string, think bool) string {
base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
if think {
return base + "<think>\n\n</think>\n\n"
}
return base
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
// If think is true, includes the <think></think> block in the chat template.
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
L := int32(maxLen)
validLen := int32(len(tokens))
combinedMaskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
combinedMaskData[idx] = 0
} else {
combinedMaskData[idx] = negInf
}
}
}
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
embeddings := te.Forward(tokensArr, maskMat, "")
return embeddings, maskArr
}
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
// If think is true, includes the <think></think> block in the chat template.
// Returns embeddings and padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
// Pad to maxLen
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
padded := make([]int32, maxLen)
copy(padded, tokens)
for i := len(tokens); i < maxLen; i++ {
padded[i] = padToken
}
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
// This combines causal masking with PAD token masking
L := int32(maxLen)
validLen := int32(len(tokens))
maskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
maskData[idx] = 0 // allowed: causal OK and not PAD
} else {
maskData[idx] = negInf // blocked: future or PAD
}
}
}
maskMat := mlx.NewArray(maskData, []int32{L, L})
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
// Concatenate layer outputs along the hidden dimension
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
embeddings := mlx.Concatenate(layerOutputs, 2)
// Return embeddings and padded length
return embeddings, int32(maxLen)
}

View File

@@ -3,12 +3,33 @@
package qwen_image package qwen_image
import ( import (
"fmt"
"os" "os"
"path/filepath"
"runtime"
"testing" "testing"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
) )
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test). // TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights not found. Requires ~50GB VRAM. // Skips if model weights not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) { func TestPipelineOutput(t *testing.T) {

View File

@@ -17,13 +17,13 @@ import (
// GenerateConfig holds all options for image generation. // GenerateConfig holds all options for image generation.
type GenerateConfig struct { type GenerateConfig struct {
Prompt string Prompt string
NegativePrompt string // Empty = no CFG NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0) CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024) Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024) Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30) Steps int // Denoising steps (default: 30)
Seed int64 // Random seed Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback Progress func(step, totalSteps int) // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup) // Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false) LayerCache bool // Enable layer caching (default: false)
@@ -31,9 +31,6 @@ type GenerateConfig struct {
CacheLayers int // Number of shallow layers to cache (default: 25) CacheLayers int // Number of shallow layers to cache (default: 25)
} }
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image diffusion model. // Model represents a Qwen-Image diffusion model.
type Model struct { type Model struct {
ModelPath string ModelPath string
@@ -117,7 +114,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
} }
// GenerateWithProgress creates an image with progress callback. // GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) { func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{ return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt, Prompt: prompt,
Width: width, Width: width,
@@ -129,7 +126,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
} }
// GenerateWithCFG creates an image with classifier-free guidance. // GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) { func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{ return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt, Prompt: prompt,
NegativePrompt: negativePrompt, NegativePrompt: negativePrompt,
@@ -172,7 +169,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024 cfg.Height = 1024
} }
if cfg.Steps <= 0 { if cfg.Steps <= 0 {
cfg.Steps = 30 cfg.Steps = 50
} }
if cfg.CFGScale <= 0 { if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0 cfg.CFGScale = 4.0

View File

@@ -18,18 +18,15 @@ import (
// GenerateConfig holds all options for image editing. // GenerateConfig holds all options for image editing.
type GenerateConfig struct { type GenerateConfig struct {
Prompt string Prompt string
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid) NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0) CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
Width int32 // Output width (default: from input image) Width int32 // Output width (default: from input image)
Height int32 // Output height (default: from input image) Height int32 // Output height (default: from input image)
Steps int // Denoising steps (default: 50) Steps int // Denoising steps (default: 50)
Seed int64 // Random seed Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback Progress func(step, totalSteps int) // Optional progress callback
} }
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
// Model represents a Qwen-Image-Edit diffusion model. // Model represents a Qwen-Image-Edit diffusion model.
type Model struct { type Model struct {
ModelPath string ModelPath string

View File

@@ -3,13 +3,35 @@
package qwen_image_edit package qwen_image_edit
import ( import (
"fmt"
"math" "math"
"os"
"path/filepath"
"runtime"
"testing" "testing"
"github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image" "github.com/ollama/ollama/x/imagegen/models/qwen_image"
) )
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestComputeAxisFreqs verifies frequency computation matches Python reference // TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) { func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000) theta := float64(10000)

View File

@@ -3,287 +3,17 @@
package zimage package zimage
import ( import (
"fmt" "github.com/ollama/ollama/x/imagegen/models/qwen3"
"math"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
) )
// Qwen3Config holds Qwen3 text encoder configuration // Re-export types from shared qwen3 package for backwards compatibility
type Qwen3Config struct { type (
HiddenSize int32 `json:"hidden_size"` Qwen3Config = qwen3.Config
NumHiddenLayers int32 `json:"num_hidden_layers"` Qwen3Attention = qwen3.Attention
IntermediateSize int32 `json:"intermediate_size"` Qwen3MLP = qwen3.MLP
NumAttentionHeads int32 `json:"num_attention_heads"` Qwen3Block = qwen3.Block
NumKeyValueHeads int32 `json:"num_key_value_heads"` Qwen3TextEncoder = qwen3.TextEncoder
VocabSize int32 `json:"vocab_size"` )
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
Attention *Qwen3Attention `weight:"self_attn"`
MLP *Qwen3MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Qwen3Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Qwen3Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Qwen3Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Qwen3Config = &cfg
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Qwen3TextEncoder) initComputedFields() {
cfg := m.Qwen3Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format // ApplyChatTemplate wraps prompt in Qwen3 chat format
func ApplyChatTemplate(prompt string) string { var ApplyChatTemplate = qwen3.ApplyChatTemplate
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
embeddings := te.Forward(tokensArr)
return embeddings, maskArr
}

Some files were not shown because too many files have changed in this diff Show More