Compare commits


18 Commits

Author SHA1 Message Date
jmorganca
8e56dab90b Add experimental image generation fields to /api/generate
Request fields (experimental):
- width: image width (max 4096)
- height: image height (max 4096)
- steps: denoising steps
- seed: random seed

Response fields (experimental):
- images: base64-encoded generated images
- completed: current step progress
- total: total steps

Other changes:
- Fix lifecycle bug where image models wouldn't unload (refCount issue)
- Fix "headers already written" error on Ctrl+C during streaming
- Add gin middleware for OpenAI /v1/images/generations compatibility
- Update CLI to use /api/generate with progress bar
- Add preload support in interactive mode
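For illustration, a request exercising the new fields might look like this (hedged sketch; `flux` stands in for any installed image generation model — see the API docs diff below for the canonical example):

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "flux",
  "prompt": "a sunset over mountains",
  "width": 1024,
  "height": 768,
  "steps": 20,
  "seed": 42
}'
```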
2026-01-17 14:08:06 -08:00
Michael
57de86cc61 docs: update claude code docs (#13757)
* docs: update claude code docs
2026-01-16 22:41:34 -08:00
Daniel Hiltgen
12719b6e87 MLX - dynamic loading of mlx-c (#13735)
* MLX - dynamic loading of mlx-c

Create a wrapper layer to indirect the dependency on mlx-c so that
the main ollama binary has no load-time dependency on mlx-c, mlx, or (on Linux) CUDA. The library is lazily loaded via dlopen
so the path can be adjusted to ensure the dependencies are found,
and loading fails gracefully if they are not present.
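A minimal sketch of the general technique, assuming a cgo-based loader (illustrative only — the package and function names are assumptions, not the actual wrapper code):

```go
// Package mlx sketches the lazy-loading approach described above:
// dlopen the wrapper library at runtime instead of linking it at
// load time, so a missing library is a recoverable error.
package mlx

/*
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdlib.h>
*/
import "C"

import (
	"fmt"
	"unsafe"
)

// load tries to dlopen the shared library at path, returning an error
// (rather than failing at process start) if it cannot be loaded.
func load(path string) (unsafe.Pointer, error) {
	cpath := C.CString(path)
	defer C.free(unsafe.Pointer(cpath))

	handle := C.dlopen(cpath, C.RTLD_LAZY|C.RTLD_LOCAL)
	if handle == nil {
		return nil, fmt.Errorf("mlx: unable to load %s: %s", path, C.GoString(C.dlerror()))
	}
	return handle, nil
}
```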

* review comments

* fix broken tests
2026-01-16 16:34:22 -08:00
Patrick Devine
a077d996e3 Fix create and show commands for experimental models (#13741)
* x: make `ollama create --experimental` import from safetensors

This change allows importing safetensors models into the new experimental model format, and also
fixes the `ollama show` command so it correctly displays the model information.

* gofumpt the linter

* gofumpt the linter again

* validate the model name
2026-01-16 14:31:55 -08:00
Jeffrey Morgan
c23d5095de x/imagegen: clean up image generation code (#13725) 2026-01-16 12:19:25 -08:00
Bruce MacDonald
7601f0e93e server: reject unexpected auth hosts (#13738)
Added validation to ensure auth redirects stay on the same host as the original request. The fix is a single check in getAuthorizationToken comparing the realm URL's host against the request host. Added tests for the auth flow.
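A hedged sketch of the kind of check described (function and parameter names are approximations, not the actual code):

```go
package auth

import (
	"fmt"
	"net/url"
)

// checkRealmHost is an illustrative version of the validation described
// above: reject an auth redirect whose realm URL points at a different
// host than the original request.
func checkRealmHost(realm, requestHost string) error {
	realmURL, err := url.Parse(realm)
	if err != nil {
		return err
	}
	if realmURL.Host != requestHost {
		return fmt.Errorf("unexpected auth host %q (request host %q)", realmURL.Host, requestHost)
	}
	return nil
}
```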

Co-Authored-By: Gecko Security <188164982+geckosecurity@users.noreply.github.com>

* gofmt

---------

Co-authored-by: Gecko Security <188164982+geckosecurity@users.noreply.github.com>
2026-01-16 14:10:36 -05:00
Eva H
aad3f03890 app: allow macOS app to terminate during system shutdown (#13737) 2026-01-16 09:05:04 -05:00
Gyungrai Wang
55d0b6e8b9 integration: fix tools_test.go for ToolCallFunctionArguments API change (#13731) 2026-01-15 16:08:09 -08:00
Devon Rifkin
38eac40d56 openai: tweak v1/responses to conform better (#13736)
* openai: tweak v1/responses to conform better

* openai: provide better error for image URLs

* lint
2026-01-15 15:46:36 -08:00
Jeffrey Morgan
80f3f1bc25 readme: add instructions to build with MLX (#13733) 2026-01-15 11:03:52 -08:00
Parth Sareen
b1a0db547b docs: add env var needed for claude code in docs (#13721) 2026-01-15 10:11:00 -08:00
Parth Sareen
75d7b5f926 cmd: enable multi-line input and shift enter (#13694) 2026-01-14 17:52:46 -08:00
vincent d warmerdam
349d814814 docs: add marimo integration (#13326)
* docs added

* fix title

* add marimo to docs.json

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:37:38 -08:00
Yuhong Sun
c8743031e0 docs: add onyx integration (#13135)
* Ready for team review

* Update docs/integrations/onyx.mdx

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* update docs.json

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2026-01-14 17:32:05 -08:00
Jeffrey Morgan
4adb9cf4bb scripts: fix macOS auto-update signature verification failure (#13713)
Add --norsrc flag to ditto commands when creating Ollama-darwin.zip
to exclude AppleDouble resource fork files (._* files) from the archive.

The mlx.metallib file has extended attributes, which causes ditto to
include a ._mlx.metallib AppleDouble file in the zip. Since this file
is not part of the code signature seal, macOS rejects the bundle during
auto-update verification with:

  "a sealed resource is missing or invalid"
  "file added: .../._mlx.metallib"

The --norsrc flag prevents ditto from preserving resource forks and
extended attributes, ensuring only signed files are included in the
release archive.
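For illustration, the packaging invocation now looks something like this (paths are placeholders; `--norsrc` is the relevant flag):

```shell
# --norsrc stops ditto from storing resource forks and extended
# attributes, so no ._* AppleDouble entries end up in the archive
ditto -c -k --keepParent --norsrc Ollama.app Ollama-darwin.zip
```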
2026-01-14 07:48:10 -08:00
Daniel Hiltgen
74f475e735 Revert "Documentation edits made through Mintlify web editor" (#13688)
This reverts commit c6d4c0c7f2.

Merge after 0.14.0 ships for the updated Linux documentation.
2026-01-14 07:42:34 -08:00
Maternion
875cecba74 docs: update default context window size to 4096 tokens (#13709) 2026-01-14 01:01:28 -08:00
Josh Daniel Bañares
7d411a4686 docs: update web search param in examples (#13711) 2026-01-14 00:38:39 -08:00
91 changed files with 13152 additions and 5765 deletions

View File

@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache \
&& dnf install -y clang ccache git \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
@@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
@@ -156,14 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
RUN mkdir -p dist/bin
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -172,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ARG CGO_CFLAGS
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CXXFLAGS
RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -185,7 +180,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/

MLX_VERSION Normal file
View File

@@ -0,0 +1 @@
v0.4.1

View File

@@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
```shell
go build -tags mlx -o ollama-mlx .
go build -tags mlx .
```
Finally, start the server:
@@ -322,6 +322,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)

View File

@@ -127,6 +127,25 @@ type GenerateRequest struct {
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Width is the width of the generated image in pixels.
// Only used for image generation models.
Width int32 `json:"width,omitempty"`
// Height is the height of the generated image in pixels.
// Only used for image generation models.
Height int32 `json:"height,omitempty"`
// Steps is the number of diffusion steps for image generation.
// Only used for image generation models.
Steps int32 `json:"steps,omitempty"`
// Seed is the random seed for reproducible image generation.
// If 0 or not specified, a random seed will be used.
// Only used for image generation models.
Seed int64 `json:"seed,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -860,6 +879,20 @@ type GenerateResponse struct {
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
// Experimental: Image generation fields (may change or be removed)
// Images contains base64-encoded generated images.
// Only present for image generation models.
Images []string `json:"images,omitempty"`
// Completed is the number of completed steps in image generation.
// Only present for image generation models during streaming.
Completed int64 `json:"completed,omitempty"`
// Total is the total number of steps for image generation.
// Only present for image generation models during streaming.
Total int64 `json:"total,omitempty"`
}
// ModelDetails provides details about a model.
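A hedged usage sketch of these fields through the Go API client (assumes an installed image generation model named `flux`; streaming progress follows the `Completed`/`Total` fields added above):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "flux", // assumed: an installed image generation model
		Prompt: "a sunset over mountains",
		Width:  1024,
		Height: 768,
	}

	// Streaming responses report step progress; the final response
	// carries the base64-encoded Images.
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		if r.Done {
			fmt.Printf("generated %d image(s)\n", len(r.Images))
		} else if r.Total > 0 {
			fmt.Printf("step %d/%d\n", r.Completed, r.Total)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```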

View File

@@ -14,6 +14,7 @@ extern NSString *SystemWidePath;
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
@property(strong, nonatomic) NSStatusItem *statusItem;
@property(assign, nonatomic) BOOL updateAvailable;
@property(assign, nonatomic) BOOL systemShutdownInProgress;
@end
@implementation AppDelegate
@@ -40,6 +41,13 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
// Register for system shutdown/restart notification so we can allow termination
[[[NSWorkspace sharedWorkspace] notificationCenter]
addObserver:self
selector:@selector(systemWillPowerOff:)
name:NSWorkspaceWillPowerOffNotification
object:nil];
// if we're in development mode, set the app icon
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
if (![bundlePath hasSuffix:@".app"]) {
@@ -278,7 +286,18 @@ bool firstTimeRun,startHidden; // Set in run before initialization
[NSApp activateIgnoringOtherApps:YES];
}
- (void)systemWillPowerOff:(NSNotification *)notification {
// Set flag so applicationShouldTerminate: knows to allow termination.
// The system will call applicationShouldTerminate: after posting this notification.
self.systemShutdownInProgress = YES;
}
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
// Allow termination if the system is shutting down or restarting
if (self.systemShutdownInProgress) {
return NSTerminateNow;
}
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
[NSApp hide:nil];
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
return NSTerminateCancel;

View File

@@ -46,8 +46,9 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -93,15 +94,87 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Validate model name early to fail fast
modelName := args[0]
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -134,7 +207,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Model = modelName
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -527,7 +600,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
// Check if this is an image generation model
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
if slices.Contains(info.Capabilities, model.CapabilityImage) {
if opts.Prompt == "" && !interactive {
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
}
@@ -1742,15 +1815,22 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",
@@ -1905,6 +1985,7 @@ func NewCLI() *cobra.Command {
} {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{

View File

@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
ParameterSize: "10.3B",
QuantizationLevel: "FP8",
},
Capabilities: []model.Capability{model.CapabilityImageGeneration},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
}, false, &b)
if err != nil {

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err

View File

@@ -16,6 +16,7 @@
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
## Conventions
@@ -58,6 +59,16 @@ Advanced parameters (optional):
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
Experimental image generation parameters (for image generation models only):
> [!WARNING]
> These parameters are experimental and may change in future versions.
- `width`: width of the generated image in pixels (default: model-specific, typically 1024)
- `height`: height of the generated image in pixels (default: model-specific, typically 1024)
- `steps`: number of diffusion steps (default: model-specific)
- `seed`: random seed for reproducible image generation (default: random)
#### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
@@ -1867,3 +1878,55 @@ curl http://localhost:11434/api/version
"version": "0.5.1"
}
```
## Experimental Features
### Image Generation (Experimental)
> [!WARNING]
> Image generation is experimental and may change in future versions.
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models (such as Flux). The API automatically detects when an image generation model is being used.
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`, `seed`) are documented there.
#### Example
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "flux",
"prompt": "a sunset over mountains",
"width": 1024,
"height": 768
}'
```
##### Response (streaming)
Progress updates during generation:
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:00.000000Z",
"completed": 5,
"total": 20,
"done": false
}
```
##### Final Response
```json
{
"model": "flux",
"created_at": "2024-01-15T10:30:15.000000Z",
"images": ["iVBORw0KGgoAAAANSUhEUg..."],
"done": true,
"done_reason": "stop",
"total_duration": 15000000000,
"load_duration": 2000000000
}
```

View File

@@ -21,6 +21,7 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
@@ -247,12 +248,13 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch({ query: "what is ollama?" });
const results = await client.webSearch("what is ollama?");
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
const fetchResult = await client.webFetch("https://ollama.com");
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -111,7 +111,9 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
"/integrations/xcode"
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

Nine binary image files added for the marimo and onyx integration docs, including:

- docs/images/marimo-chat.png (80 KiB)
- docs/images/onyx-login.png (100 KiB)
- docs/images/onyx-query.png (211 KiB)
- six additional images (174, 230, 178, 186, 306, and 300 KiB) whose names are not shown

View File

@@ -2,6 +2,12 @@
title: Claude Code
---
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder` and `gpt-oss:20b`.
![Claude Code with Ollama](https://files.ollama.com/claude-code.png)
## Install
Install [Claude Code](https://code.claude.com/docs/en/overview):
@@ -25,22 +31,24 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
2. Run Claude Code with an Ollama model:
```shell
claude --model qwen3-coder
claude --model gpt-oss:20b
```
Or run with environment variables inline:
```shell
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
```
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
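As a hedged example, the limit can be raised with the `OLLAMA_CONTEXT_LENGTH` variable documented in the FAQ:

```shell
# illustrative: raise the default context window to 32K before serving
export OLLAMA_CONTEXT_LENGTH=32768
ollama serve
```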
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
@@ -67,3 +75,4 @@ claude --model glm-4.7:cloud
### Local models
- `qwen3-coder` - Excellent for coding tasks
- `gpt-oss:20b` - Strong general-purpose model
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks

View File

@@ -0,0 +1,73 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, open the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base URL to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
5. Alternatively, you can use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to Ollama Cloud via `ollama signin`
2. In the Ollama model settings, add a model that Ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -0,0 +1,63 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Log in to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the setup process, select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also easily connect Onyx Cloud using the `Ollama Cloud` tab of the setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
Start Ollama:
@@ -40,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
```
### ARM64 install
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
## Installing specific versions
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -1464,6 +1464,12 @@ type CompletionRequest struct {
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
// Image generation fields
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// DoneReason represents the reason why a completion response is done
@@ -1512,6 +1518,15 @@ type CompletionResponse struct {
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
// Image contains base64-encoded image data for image generation
Image string `json:"image,omitempty"`
// Step is the current step in image generation
Step int `json:"step,omitempty"`
// TotalSteps is the total number of steps for image generation
TotalSteps int `json:"total_steps,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

View File

@@ -8,6 +8,7 @@ import (
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -441,6 +442,7 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -478,7 +480,9 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -523,11 +527,12 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
request: req,
}
// Set headers based on streaming mode
@@ -541,3 +546,66 @@ func ResponsesMiddleware() gin.HandlerFunc {
c.Next()
}
}
type ImageWriter struct {
BaseWriter
}
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
if err := json.Unmarshal(data, &generateResponse); err != nil {
return 0, err
}
// Only write response when done with images
if generateResponse.Done && len(generateResponse.Images) > 0 {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
}
return len(data), nil
}
func (w *ImageWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ImageGenerationsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageGenerationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
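A hedged usage example of the OpenAI-compatible endpoint this middleware serves (`flux` is a placeholder model name; `size` is parsed into width/height by `FromImageGenerationRequest` below):

```shell
curl http://localhost:11434/v1/images/generations -d '{
  "model": "flux",
  "prompt": "a sunset over mountains",
  "size": "1024x768"
}'
```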

View File

@@ -961,3 +961,154 @@ func TestRetrieveMiddleware(t *testing.T) {
}
}
}
func TestImageGenerationsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "image generation basic",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
},
},
{
name: "image generation with size",
body: `{
"model": "test-model",
"prompt": "a beautiful sunset",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "a beautiful sunset",
Width: 512,
Height: 768,
},
},
{
name: "image generation missing prompt",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image generation missing model",
body: `{
"prompt": "a beautiful sunset"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}
func TestImageWriterResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
// Test that ImageWriter transforms GenerateResponse to OpenAI format
endpoint := func(c *gin.Context) {
resp := api.GenerateResponse{
Model: "test-model",
CreatedAt: time.Unix(1234567890, 0).UTC(),
Done: true,
Images: []string{"dGVzdC1pbWFnZS1kYXRh"}, // base64 of "test-image-data"
}
data, _ := json.Marshal(resp)
c.Writer.Write(append(data, '\n'))
}
router := gin.New()
router.Use(ImageGenerationsMiddleware())
router.Handle(http.MethodPost, "/api/generate", endpoint)
body := `{"model": "test-model", "prompt": "test"}`
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
var imageResp openai.ImageGenerationResponse
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if imageResp.Created != 1234567890 {
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
}
if len(imageResp.Data) != 1 {
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
}
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}

View File

@@ -630,6 +630,10 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
@@ -733,3 +737,57 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Seed *int64 `json:"seed,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageURLOrData `json:"data"`
}
// ImageURLOrData contains either a URL or base64-encoded image data.
type ImageURLOrData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
}
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
req.Seed = *r.Seed
}
return req
}
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
data := make([]ImageURLOrData, 0)
for _, img := range resp.Images {
data = append(data, ImageURLOrData{B64JSON: img})
}
return ImageGenerationResponse{
Created: resp.CreatedAt.Unix(),
Data: data,
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"time"
"github.com/ollama/ollama/api"
)
@@ -265,9 +266,9 @@ type ResponsesText struct {
type ResponsesTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
Description *string `json:"description"` // nullable but required
Strict *bool `json:"strict"` // nullable but required
Parameters map[string]any `json:"parameters"` // nullable but required
}
type ResponsesRequest struct {
@@ -475,11 +476,16 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
}
}
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{
Type: t.Type,
Function: api.ToolFunction{
Name: t.Name,
Description: t.Description,
Description: description,
Parameters: params,
},
}, nil
@@ -516,17 +522,60 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Status string `json:"status"`
Model string `json:"model"`
Output []ResponsesOutputItem `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
// requires additional plumbing to find the effective values since the
// defaults can come from the model or the request
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at"`
Status string `json:"status"`
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
Model string `json:"model"`
PreviousResponseID *string `json:"previous_response_id"`
Instructions *string `json:"instructions"`
Output []ResponsesOutputItem `json:"output"`
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
}
type ResponsesOutputItem struct {
@@ -550,18 +599,39 @@ type ResponsesReasoningSummary struct {
}
type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
}
// ToResponse converts an api.ChatResponse to a Responses API response
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
func derefFloat64(p *float64, def float64) float64 {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
var output []ResponsesOutputItem
// Add reasoning item if thinking is present
@@ -585,6 +655,7 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
@@ -598,25 +669,90 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse)
Role: "assistant",
Content: []ResponsesOutputContent{
{
Type: "output_text",
Text: chatResponse.Message.Content,
Type: "output_text",
Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
},
},
})
}
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
Status: "completed",
Model: model,
Output: output,
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
CompletedAt: nil, // Set by middleware when writing final response
Status: "completed",
IncompleteDetails: nil, // Only populated if response incomplete
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
},
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
}
}
@@ -636,6 +772,7 @@ type ResponsesStreamConverter struct {
responseID string
itemID string
model string
request ResponsesRequest
// State tracking (mutated across Process calls)
firstWrite bool
@@ -668,11 +805,12 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}
// NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
return &ResponsesStreamConverter{
responseID: responseID,
itemID: itemID,
model: model,
request: request,
firstWrite: true,
}
}
@@ -717,25 +855,120 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events
}
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
"response": c.buildResponseObject("in_progress", []any{}, nil),
})
}
@@ -762,9 +995,10 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"delta": thinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"delta": thinking,
}))
// TODO(drifkin): consider adding
@@ -783,9 +1017,10 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"text": c.accumulatedThinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"text": c.accumulatedThinking,
}),
c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex,
@@ -898,8 +1133,10 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": c.contentIndex,
"part": map[string]any{
"type": "output_text",
"text": "",
"type": "output_text",
"text": "",
"annotations": []any{},
"logprobs": []any{},
},
}))
}
@@ -913,6 +1150,7 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": 0,
"delta": content,
"logprobs": []any{},
}))
return events
@@ -944,8 +1182,10 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
})
}
@@ -967,6 +1207,7 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"text": c.accumulatedText,
"logprobs": []any{},
}))
// response.content_part.done
@@ -975,8 +1216,10 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"part": map[string]any{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
},
}))
@@ -989,26 +1232,31 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
}},
},
}))
}
// response.completed
events = append(events, c.newEvent("response.completed", map[string]any{
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{
"response": response,
}))
return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42",
},
Done: true,
})
}, ResponsesRequest{})
// Should have 2 output items: reasoning + message
if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
events := converter.Process(api.ChatResponse{
Message: api.Message{


@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -36,10 +37,11 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
}
func New(prompt Prompt) (*Instance, error) {
@@ -174,6 +176,8 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -188,7 +192,23 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
buf.Remove()
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -211,13 +231,28 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false
return output, nil
default:


@@ -60,7 +60,7 @@ _build_darwin() {
cmake --install $BUILD_DIR --component MLX
# Override CGO flags to point to the amd64 build directory
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
else
BUILD_DIR=build
cmake --preset MLX \
@@ -71,10 +71,12 @@ _build_darwin() {
cmake --install $BUILD_DIR --component MLX
# Use default CGO flags from mlx.go for arm64
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
fi
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
# Copy MLX libraries to same directory as executable for dlopen
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
done
}
@@ -82,12 +84,10 @@ _sign_darwin() {
status "Creating universal binary..."
mkdir -p dist/darwin
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
chmod +x dist/darwin/ollama
chmod +x dist/darwin/ollama-mlx
if [ -n "$APPLE_IDENTITY" ]; then
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
done
@@ -154,7 +154,6 @@ _build_macapp() {
mkdir -p dist/Ollama.app/Contents/Resources
if [ -d dist/darwin-amd64 ]; then
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
done
@@ -166,28 +165,27 @@ _build_macapp() {
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
fi
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
chmod a+x dist/Ollama.app/Contents/Resources/ollama
# Sign
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
done
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
fi
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
if [ -n "$APPLE_IDENTITY" ]; then
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg


@@ -50,12 +50,17 @@ func (r registryChallenge) URL() (*url.URL, error) {
return redirectURL, nil
}
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
if redirectURL.Host != originalHost {
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))

server/auth_test.go (new file, 113 lines)

@@ -0,0 +1,113 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
tests := []struct {
realm string
originalHost string
wantMismatch bool
}{
{"https://example.com/token", "example.com", false},
{"https://example.com/token", "other.com", true},
{"https://example.com/token", "localhost:8000", true},
{"https://localhost:5000/token", "localhost:5000", false},
{"https://localhost:5000/token", "localhost:6000", true},
}
for _, tt := range tests {
t.Run(tt.originalHost, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
if tt.wantMismatch && !isMismatch {
t.Errorf("expected domain mismatch error, got: %v", err)
}
if !tt.wantMismatch && isMismatch {
t.Errorf("unexpected domain mismatch error: %v", err)
}
})
}
}
func TestParseRegistryChallenge(t *testing.T) {
tests := []struct {
input string
wantRealm, wantService, wantScope string
}{
{
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
"https://auth.example.com/token", "registry", "repo:foo:pull",
},
{
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
"https://r.ollama.ai/v2/token", "ollama", "-",
},
{"", "", "", ""},
}
for _, tt := range tests {
result := parseRegistryChallenge(tt.input)
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
tt.input, result.Realm, result.Service, result.Scope,
tt.wantRealm, tt.wantService, tt.wantScope)
}
}
}
func TestRegistryChallengeURL(t *testing.T) {
challenge := registryChallenge{
Realm: "https://auth.example.com/token",
Service: "registry",
Scope: "repo:foo:pull repo:bar:push",
}
u, err := challenge.URL()
if err != nil {
t.Fatalf("URL() error: %v", err)
}
if u.Host != "auth.example.com" {
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
}
if u.Path != "/token" {
t.Errorf("path = %q, want %q", u.Path, "/token")
}
q := u.Query()
if q.Get("service") != "registry" {
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
}
if scopes := q["scope"]; len(scopes) != 2 {
t.Errorf("scope count = %d, want 2", len(scopes))
}
if q.Get("ts") == "" {
t.Error("missing ts")
}
if q.Get("nonce") == "" {
t.Error("missing nonce")
}
// Nonces should differ between calls
u2, _ := challenge.URL()
if q.Get("nonce") == u2.Query().Get("nonce") {
t.Error("nonce should be unique per call")
}
}
func TestRegistryChallengeURLInvalid(t *testing.T) {
challenge := registryChallenge{Realm: "://invalid"}
if _, err := challenge.URL(); err == nil {
t.Error("expected error for invalid URL")
}
}


@@ -41,6 +41,7 @@ var (
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errCapabilityImage = errors.New("image generation")
errInsecureProtocol = errors.New("insecure protocol http")
)
@@ -76,7 +77,7 @@ func (m *Model) Capabilities() []model.Capability {
// Check for image generation model via config capabilities
if slices.Contains(m.Config.Capabilities, "image") {
return []model.Capability{model.CapabilityImageGeneration}
return []model.Capability{model.CapabilityImage}
}
// Check for completion capability
@@ -159,6 +160,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
model.CapabilityImage: errCapabilityImage,
}
for _, cap := range want {
@@ -775,7 +777,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}
if err := transfer.Download(ctx, transfer.DownloadOptions{
@@ -850,7 +852,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Realm: challenge.Realm,
Service: challenge.Service,
Scope: challenge.Scope,
})
}, base.Host)
}
return transfer.Upload(ctx, transfer.UploadOptions{
@@ -916,7 +918,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
// Handle authentication error with one retry
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return nil, err
}


@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with completion capability",
@@ -242,6 +242,24 @@ func TestModelCheckCapabilities(t *testing.T) {
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
{
name: "model missing image generation capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityImage},
expectedErrMsg: "does not support image generation",
},
{
name: "model with image generation capability",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
checkCaps: []model.Capability{model.CapabilityImage},
},
}
for _, tt := range tests {


@@ -51,7 +51,7 @@ import (
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -164,29 +164,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
// ScheduleImageGenRunner schedules an image generation model runner.
// This implements the imagegenapi.RunnerScheduler interface.
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
m := &Model{
Name: modelName,
ShortName: modelName,
ModelPath: modelName, // For image gen, ModelPath is just the model name
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err := <-errCh:
return nil, err
}
return runner.llama, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
@@ -214,12 +191,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Check if this is a known image generation model
if imagegen.ResolveModelName(req.Model) != "" {
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -249,6 +220,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
@@ -1125,7 +1102,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
@@ -1133,6 +1110,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
if req.System != "" {
m.System = req.System
}
@@ -1215,7 +1208,27 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
@@ -1587,13 +1600,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
// Experimental image generation support
imagegenapi.RegisterRoutes(r, s)
if rc != nil {
// wrap old with new
rs := &registry.Local{
@@ -2460,3 +2472,78 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}
// handleImageGenerate handles image generation requests within GenerateHandler.
// It is called when the model reports the image capability (model.CapabilityImage).
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
// Validate image dimensions
const maxDimension int32 = 4096
if req.Width > maxDimension || req.Height > maxDimension {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
return
}
// Schedule the runner for image generation
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// Handle load-only request (empty prompt)
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
// Set headers for streaming response
c.Header("Content-Type", "application/x-ndjson")
var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}
if cr.TotalSteps > 0 {
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}
if cr.Image != "" {
res.Images = []string{cr.Image}
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.Metrics.TotalDuration = time.Since(checkpointStart)
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
}
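A minimal client-side sketch of consuming this streaming handler (an illustration, not part of the changeset): the request and response field names mirror the experimental fields used above (width, height, steps, images, completed, total); the model name is hypothetical and the default server address is assumed.

package main

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

func main() {
	// "flux" is a hypothetical model name; localhost:11434 is the default address.
	body, err := json.Marshal(map[string]any{
		"model":  "flux",
		"prompt": "a lighthouse at dusk",
		"width":  1024,
		"height": 1024,
		"steps":  20,
	})
	if err != nil {
		panic(err)
	}
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 1<<20), 1<<28) // base64-encoded images can exceed the default token size
	for sc.Scan() {
		var chunk struct {
			Completed int64    `json:"completed"`
			Total     int64    `json:"total"`
			Images    []string `json:"images"`
			Done      bool     `json:"done"`
		}
		if err := json.Unmarshal(sc.Bytes(), &chunk); err != nil {
			panic(err)
		}
		if chunk.Total > 0 {
			fmt.Printf("step %d/%d\n", chunk.Completed, chunk.Total)
		}
		for i, img := range chunk.Images {
			data, err := base64.StdEncoding.DecodeString(img)
			if err != nil {
				panic(err)
			}
			_ = os.WriteFile(fmt.Sprintf("image-%d.png", i), data, 0o644)
		}
	}
	if err := sc.Err(); err != nil {
		panic(err)
	}
}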


@@ -571,10 +571,10 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
model: req.model,
modelPath: req.model.ModelPath,
llama: server,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
refCount: 1,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
}
s.loadedMu.Lock()


@@ -6,7 +6,6 @@ import (
"errors"
"log/slog"
"os"
"slices"
"testing"
"time"
@@ -17,7 +16,6 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
func TestMain(m *testing.M) {
@@ -807,32 +805,8 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
// TestImageGenCapabilityDetection verifies that models with "image" capability
// are correctly identified and routed differently from language models.
func TestImageGenCapabilityDetection(t *testing.T) {
// Model with image capability should be detected
imageModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
}
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
// Model without image capability should not be detected
langModel := &Model{
Config: model.ConfigV2{
Capabilities: []string{"completion"},
},
}
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
// Empty capabilities should not match
emptyModel := &Model{}
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
}
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted by a language model request.
// loaded in the scheduler can be evicted when idle.
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
@@ -864,3 +838,59 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
require.NotNil(t, runner)
require.Equal(t, "/fake/image/model", runner.modelPath)
}
// TestImageGenSchedulerCoexistence verifies that image generation models
// can coexist with language models in the scheduler and VRAM is tracked correctly.
func TestImageGenSchedulerCoexistence(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
// Load both an imagegen runner and a language model runner
imageGenRunner := &runnerRef{
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
modelPath: "/fake/flux/model",
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
langModelRunner := &runnerRef{
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
modelPath: "/fake/llama3/model",
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
sessionDuration: 10 * time.Millisecond,
numParallel: 1,
refCount: 0,
}
s.loadedMu.Lock()
s.loaded["/fake/flux/model"] = imageGenRunner
s.loaded["/fake/llama3/model"] = langModelRunner
s.loadedMu.Unlock()
// Verify both are loaded
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
require.NotNil(t, s.loaded["/fake/flux/model"])
require.NotNil(t, s.loaded["/fake/llama3/model"])
s.loadedMu.Unlock()
// Verify updateFreeSpace accounts for both
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
TotalMemory: 24 * format.GigaByte,
FreeMemory: 24 * format.GigaByte,
},
}
s.updateFreeSpace(gpus)
// Free memory should be reduced by both models
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
require.Equal(t, expectedFree, gpus[0].FreeMemory)
}


@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}


@@ -9,7 +9,7 @@ const (
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
CapabilityImageGeneration = Capability("image")
CapabilityImage = Capability("image")
)
func (c Capability) String() string {


@@ -25,14 +25,6 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -656,7 +648,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
AltPlaceholder: "Press Enter to send",
})
if err != nil {
return err
@@ -707,7 +699,6 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -721,37 +712,12 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -860,41 +826,18 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
fmt.Println("Usage: /set system <message>")
continue
}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1081,7 +1024,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 && multiline == MultilineNone {
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)

x/create/client/create.go (new file, 282 lines)

@@ -0,0 +1,282 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server; this sub-package, however, can import server because
// server does not import the sub-package.
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
if !isSafetensors && !isImageGen {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
capabilities = []string{"image"}
}
// Set up progress spinner
statusMsg := "importing " + modelType
spinner := progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
progressFn := func(msg string) {
spinner.Stop()
statusMsg = msg
spinner = progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var err error
if isSafetensors {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
}
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
return nil
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
}
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if quantize != "" {
return createQuantizedLayers(r, name, dtype, shape, quantize)
}
return createUnquantizedLayer(r, name)
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []create.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, 0, len(layers))
for _, l := range layers {
serverLayers = append(serverLayers, server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
})
}
// Add Modelfile layers if present
if opts.Modelfile != nil {
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
layers = append(layers, layer)
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
layers = append(layers, layer)
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}


@@ -0,0 +1,146 @@
package client
import (
"testing"
)
func TestModelfileConfig(t *testing.T) {
// Test that ModelfileConfig struct works as expected
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
}
if config.Template != "{{ .Prompt }}" {
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
}
if config.System != "You are a helpful assistant." {
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
}
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
if config.Template != "" {
t.Errorf("Template should be empty, got %q", config.Template)
}
if config.System != "" {
t.Errorf("System should be empty, got %q", config.System)
}
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
// Test config with only some fields set
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
// System and License intentionally empty
}
if config.Template == "" {
t.Error("Template should not be empty")
}
if config.System != "" {
t.Error("System should be empty")
}
if config.License != "" {
t.Error("License should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
// Verify the minimum version constant is set
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
}
}
func TestCreateModel_InvalidDir(t *testing.T) {
// Test that CreateModel returns error for invalid directory
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: "/nonexistent/path",
}, nil)
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
// Test that CreateModel returns error for directory without safetensors
dir := t.TempDir()
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: dir,
}, nil)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
License: "MIT",
},
}
if opts.ModelName != "my-model" {
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
}
if opts.ModelDir != "/path/to/model" {
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
}
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
}
func TestCreateOptions_Defaults(t *testing.T) {
opts := CreateOptions{
ModelName: "test",
ModelDir: "/tmp",
}
// Quantize should default to empty
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
t.Error("Modelfile should be nil by default")
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
supported := QuantizeSupported()
// In non-mlx builds, this should be false
// We can't easily test both cases, so just verify it returns something
_ = supported
}


@@ -11,10 +11,11 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx"
)
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
// quantizeTensor loads a tensor from safetensors format, quantizes it,
// and returns safetensors data for the quantized weights, scales, and biases.
// Supported quantization types: "fp8" (affine 8-bit)
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
tmpDir := ensureTempDir()
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
@@ -50,9 +51,15 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData
mlx.Eval(arr)
}
// Quantize with affine mode: group_size=32, bits=8
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
// Quantize based on quantization type
var qweight, scales, qbiases *mlx.Array
switch quantize {
case "fp8":
// affine mode: group_size=32, bits=8
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
default:
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
// Eval and make contiguous for data access
qweight = mlx.Contiguous(qweight)
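To put the affine settings in perspective (a back-of-the-envelope estimate, assuming fp16 scales and biases): a 4096x4096 weight quantized with group_size=32 and bits=8 stores 16 MiB of one-byte codes plus 4096*4096/32 = 524,288 scales and as many biases at 2 bytes each, roughly 2 MiB combined, for about 18 MiB total versus 64 MiB in fp32 or 32 MiB in bf16.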


@@ -8,7 +8,7 @@ import (
)
// quantizeTensor is not available without MLX
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}

x/create/create.go (new file, 399 lines)

@@ -0,0 +1,399 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
ModelFormat string `json:"model_format"`
Capabilities []string `json:"capabilities"`
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"`
}
// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
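A few illustrative resolutions under the default layout (an illustration; the elided prefix is envconfig.Models()/manifests):

resolveManifestPath("llama3")                    // …/registry.ollama.ai/library/llama3/latest
resolveManifestPath("me/llama3:dev")             // …/registry.ollama.ai/me/llama3/dev
resolveManifestPath("example.com/me/llama3:dev") // …/example.com/me/llama3/dev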
// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, err
}
return &manifest, nil
}
// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return nil, err
}
// Read the config blob
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return nil, err
}
var config ModelConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors"
}
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}
// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}
// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return "", err
}
// Find the config.json layer
for _, layer := range manifest.Layers {
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
blobName := strings.Replace(layer.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return "", err
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", err
}
// Prefer model_type, fall back to first architecture
if cfg.ModelType != "" {
return cfg.ModelType, nil
}
if len(cfg.Architectures) > 0 {
return cfg.Architectures[0], nil
}
}
}
return "", fmt.Errorf("architecture not found in model config")
}
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
// Must have config.json
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
return false
}
// Must have at least one .safetensors file
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
return true
}
}
return false
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
// Image gen specific: skip VAE entirely
if component == "vae" {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
}
// Skip layer norms and RMS norms
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
return false
}
// Skip biases
if strings.HasSuffix(name, ".bias") {
return false
}
// Only quantize weights
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if int64(shape[0])*int64(shape[1]) < 1024 {
return false
}
// MLX quantization requires last dimension to be divisible by group size (32)
if shape[len(shape)-1]%32 != 0 {
return false
}
return true
}
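For concreteness, a few illustrative calls (tensor names and shapes are hypothetical):

ShouldQuantizeTensor("layers.0.mlp.up_proj.weight", []int32{4096, 11008}) // true: large 2D linear weight, 11008%32 == 0
ShouldQuantizeTensor("embed_tokens.weight", []int32{32000, 4096})         // false: embedding
ShouldQuantizeTensor("norm.weight", []int32{4096})                        // false: norm, and not 2D
ShouldQuantizeTensor("layers.0.mlp.up_proj.weight", []int32{4096, 30})    // false: last dim not divisible by 32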
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
}
// Process all safetensors files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
quantizeType = quantize
}
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
// Skip the index file as we don't need it after extraction
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
f.Close()
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use config.json as the config layer
if cfgPath == "config.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("config.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}
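A minimal dry-run sketch of wiring CreateSafetensorsModel with stub callbacks (an illustration: the digest and tensor media type below are placeholders, and a real caller wires these to blob storage the way x/create/client does):

package main

import (
	"fmt"
	"io"

	"github.com/ollama/ollama/x/create"
)

func main() {
	// Measure layers instead of writing blobs.
	layer := func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
		n, err := io.Copy(io.Discard, r)
		return create.LayerInfo{Digest: "sha256:stub", Size: n, MediaType: mediaType, Name: name}, err
	}
	tensor := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
		l, err := layer(r, "application/octet-stream", name) // placeholder media type
		return []create.LayerInfo{l}, err
	}
	manifest := func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
		fmt.Printf("would write manifest for %s with %d layers\n", modelName, len(layers))
		return nil
	}
	err := create.CreateSafetensorsModel("demo", "/path/to/hf-model", "", layer, tensor, manifest,
		func(status string) { fmt.Println(status) })
	if err != nil {
		fmt.Println("error:", err)
	}
}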

x/create/create_test.go (new file, 752 lines)

@@ -0,0 +1,752 @@
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestIsTensorModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid diffusers model with model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
},
expected: true,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "directory with other files but no model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsTensorModelDir(dir)
if got != tt.expected {
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid safetensors model with config.json and .safetensors file",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
{
name: "config.json only, no safetensors files",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
},
expected: false,
},
{
name: "safetensors file only, no config.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
},
expected: false,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "multiple safetensors files with config.json",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsSafetensorsModelDir(dir)
if got != tt.expected {
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
if got != false {
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
}
}
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
t.Helper()
// Create a minimal safetensors file with a single float32 tensor
header := map[string]interface{}{
"test_tensor": map[string]interface{}{
"dtype": "F32",
"shape": []int{2, 2},
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
},
}
headerJSON, err := json.Marshal(header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Write file
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}
defer f.Close()
// Write header size (8 bytes, little endian)
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
// Write header
if _, err := f.Write(headerJSON); err != nil {
t.Fatalf("failed to write header: %v", err)
}
// Write tensor data (16 bytes of zeros for 4 float32 values)
if _, err := f.Write(make([]byte, 16)); err != nil {
t.Fatalf("failed to write tensor data: %v", err)
}
}
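The helper above writes the safetensors container layout: an 8-byte little-endian header length, the JSON header itself (space-padded so the tensor data starts on an 8-byte boundary), then the raw tensor bytes at the offsets declared in the header:

[0, 8)       uint64 N = header size, little-endian
[8, 8+N)     JSON header, space-padded
[8+N, end)   tensor data at the declared data_offsets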
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Track what was created
var createdLayers []LayerInfo
var manifestWritten bool
var manifestModelName string
var manifestConfigLayer LayerInfo
var manifestLayers []LayerInfo
var statusMessages []string
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
layer := LayerInfo{
Digest: "sha256:test",
Size: int64(len(data)),
MediaType: mediaType,
Name: name,
}
createdLayers = append(createdLayers, layer)
return layer, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
layer := LayerInfo{
Digest: "sha256:tensor_" + name,
Size: int64(len(data)),
MediaType: "application/vnd.ollama.image.tensor",
Name: name,
}
createdLayers = append(createdLayers, layer)
return []LayerInfo{layer}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
manifestConfigLayer = config
manifestLayers = layers
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
// Run CreateSafetensorsModel
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify manifest was written
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-model" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
}
// Verify config layer was set
if manifestConfigLayer.Name != "config.json" {
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
}
// Verify we have at least one tensor and one config layer
hasTensor := false
hasConfig := false
for _, layer := range manifestLayers {
if layer.Name == "test_tensor" {
hasTensor = true
}
if layer.Name == "config.json" {
hasConfig = true
}
}
if !hasTensor {
t.Error("no tensor layer found in manifest")
}
if !hasConfig {
t.Error("no config layer found in manifest")
}
// Verify status messages were sent
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()
// Create only a safetensors file, no config.json
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Mock callbacks (minimal)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing config.json, got nil")
}
}
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
dir := t.TempDir()
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
return []LayerInfo{{}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
dir := t.TempDir()
// Create config.json
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create model.safetensors.index.json (should be skipped)
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
t.Fatalf("failed to write index.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var configNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
configNames = append(configNames, name)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify model.safetensors.index.json was not included
for _, name := range configNames {
if name == "model.safetensors.index.json" {
t.Error("model.safetensors.index.json should have been skipped")
}
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
modelName string
wantParts []string // Parts that should appear in the path
}{
{
name: "simple model name",
modelName: "llama2",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
},
{
name: "model name with tag",
modelName: "llama2:7b",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
},
{
name: "model name with namespace",
modelName: "myuser/mymodel",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
},
{
name: "model name with namespace and tag",
modelName: "myuser/mymodel:v1",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
},
{
name: "fully qualified model name",
modelName: "registry.example.com/namespace/model:tag",
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveManifestPath(tt.modelName)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
}
}
})
}
}
func TestLayerInfo(t *testing.T) {
layer := LayerInfo{
Digest: "sha256:abc123",
Size: 1024,
MediaType: "application/vnd.ollama.image.tensor",
Name: "model.weight",
}
if layer.Digest != "sha256:abc123" {
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
}
if layer.Size != 1024 {
t.Errorf("Size = %d, want %d", layer.Size, 1024)
}
if layer.MediaType != "application/vnd.ollama.image.tensor" {
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
}
if layer.Name != "model.weight" {
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
}
}
func TestModelConfig(t *testing.T) {
config := ModelConfig{
ModelFormat: "safetensors",
Capabilities: []string{"completion", "chat"},
}
if config.ModelFormat != "safetensors" {
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
}
if len(config.Capabilities) != 2 {
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
}
}
func TestManifest(t *testing.T) {
manifest := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.oci.image.manifest.v1+json",
Config: ManifestLayer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:config",
Size: 100,
},
Layers: []ManifestLayer{
{
MediaType: "application/vnd.ollama.image.tensor",
Digest: "sha256:layer1",
Size: 1000,
Name: "weight.bin",
},
},
}
if manifest.SchemaVersion != 2 {
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
}
if manifest.Config.Digest != "sha256:config" {
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
}
if len(manifest.Layers) != 1 {
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
}
if manifest.Layers[0].Name != "weight.bin" {
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
}
}
func TestShouldQuantize(t *testing.T) {
tests := []struct {
name string
tensor string
component string
want bool
}{
// VAE component should never be quantized
{"vae weight", "decoder.weight", "vae", false},
{"vae bias", "decoder.bias", "vae", false},
// Embeddings should not be quantized
{"embedding weight", "embed_tokens.weight", "", false},
{"embedding in name", "token_embedding.weight", "", false},
// Norms should not be quantized
{"layer norm", "layer_norm.weight", "", false},
{"rms norm", "rms_norm.weight", "", false},
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
// Linear weights should be quantized
{"linear weight", "q_proj.weight", "", true},
{"attention weight", "self_attn.weight", "", true},
{"mlp weight", "mlp.gate_proj.weight", "", true},
// Transformer component weights should be quantized
{"transformer weight", "layers.0.weight", "transformer", true},
{"text_encoder weight", "encoder.weight", "text_encoder", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantize(tt.tensor, tt.component)
if got != tt.want {
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
}
})
}
}
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
// Run with quantize enabled
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify the quantize type was passed to the callback (it will be empty for
// the small test tensor, which is below the quantization size threshold)
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}
// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
t.Helper()
// Create model_index.json
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
t.Fatalf("failed to write model_index.json: %v", err)
}
// Create transformer directory with a safetensors file
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
// Create transformer config
transformerConfig := `{"hidden_size": 3072}`
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
t.Fatalf("failed to write transformer config: %v", err)
}
}
func TestCreateImageGenModel(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var manifestWritten bool
var manifestModelName string
var statusMessages []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-imagegen" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
}
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
dir := t.TempDir()
// Create only transformer without model_index.json
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing model_index.json, got nil")
}
}
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var quantizeRequested []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}

View File

@@ -1,4 +1,4 @@
package imagegen
package create
import (
"bytes"
@@ -12,43 +12,27 @@ import (
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
// Supported quantization types: fp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
case "", "fp8":
// valid
default:
return fmt.Errorf("unsupported quantization type %q: supported types are fp8", quantize)
}
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
var torchDtype string // Read from component config for quantization display
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
@@ -77,8 +61,8 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize == "fp8" && component != "vae" {
quantizeMsg = ", quantizing to fp8"
if quantize != "" && component != "vae" {
quantizeMsg = ", quantizing to " + quantize
}
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
@@ -103,11 +87,14 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
// Use path-style name: "component/tensor_name"
fullName := component + "/" + tensorName
// Determine if this tensor should be quantized
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
quantizeType = quantize
}
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
@@ -119,6 +106,19 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
}
}
// Read torch_dtype from text_encoder config for quantization display
if torchDtype == "" {
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
if data, err := os.ReadFile(textEncoderConfig); err == nil {
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
torchDtype = cfg.TorchDtype
}
}
}
// Import config files
configFiles := []string{
"model_index.json",
@@ -126,13 +126,10 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"vision_language_encoder/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"processor/tokenizer.json", // GLM-Image main tokenizer
"processor/tokenizer_config.json", // GLM-Image tokenizer config
}
for _, cfgPath := range configFiles {
@@ -167,11 +164,11 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
// Add parameter count (counted from tensor shapes during import)
cfg["parameter_count"] = totalParams
// Add quantization info
if quantize == "fp8" {
cfg["quantization"] = "FP8"
// Add quantization info - use quantize type if set, otherwise torch_dtype
if quantize != "" {
cfg["quantization"] = strings.ToUpper(quantize)
} else {
cfg["quantization"] = "BF16"
cfg["quantization"] = torchDtype
}
data, err = json.MarshalIndent(cfg, "", " ")
@@ -214,3 +211,12 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
// MLX requires the last dimension to be divisible by the group size (32).
func canQuantizeShape(shape []int32) bool {
if len(shape) < 2 {
return false
}
return shape[len(shape)-1]%32 == 0
}
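A few illustrative calls (shapes chosen for illustration):

canQuantizeShape([]int32{4096, 4096}) // true:  2-D, and 4096 % 32 == 0
canQuantizeShape([]int32{128, 100})   // false: last dimension 100 % 32 != 0
canQuantizeShape([]int32{4096})       // false: fewer than 2 dimensions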

View File

@@ -1,231 +0,0 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/imagegen"
)
// RunnerScheduler is the interface for scheduling a model runner.
// This is implemented by server.Server to avoid circular imports.
type RunnerScheduler interface {
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
}
// RegisterRoutes registers the image generation API routes.
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
r.POST("/v1/images/generations", func(c *gin.Context) {
ImageGenerationHandler(c, scheduler)
})
}
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
var req ImageGenerationRequest
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Validate required fields
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
return
}
if req.Prompt == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
return
}
// Apply defaults
if req.N == 0 {
req.N = 1
}
if req.Size == "" {
req.Size = "1024x1024"
}
if req.ResponseFormat == "" {
req.ResponseFormat = "b64_json"
}
// Verify model exists
if imagegen.ResolveModelName(req.Model) == "" {
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
return
}
// Parse size
width, height := parseSize(req.Size)
// Build options - we repurpose NumCtx/NumGPU for width/height
opts := api.Options{}
opts.NumCtx = int(width)
opts.NumGPU = int(height)
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
if err != nil {
status := http.StatusInternalServerError
if strings.Contains(err.Error(), "not found") {
status = http.StatusNotFound
}
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: req.Prompt,
Options: &opts,
}
if req.Stream {
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
} else {
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
}
}
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
} else {
progress := parseProgress(resp.Content)
if progress.Total > 0 {
c.SSEvent("progress", progress)
c.Writer.Flush()
}
}
})
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
c.SSEvent("done", buildResponse(imageBase64, format))
}
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
var imageBase64 string
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
if resp.Done {
imageBase64 = extractBase64(resp.Content)
}
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
return
}
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
}
func parseSize(size string) (int32, int32) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return 1024, 1024
}
w, _ := strconv.Atoi(parts[0])
h, _ := strconv.Atoi(parts[1])
if w == 0 {
w = 1024
}
if h == 0 {
h = 1024
}
return int32(w), int32(h)
}
func extractBase64(content string) string {
const prefix = "IMAGE_BASE64:"
if strings.HasPrefix(content, prefix) {
return content[len(prefix):]
}
return ""
}
func parseProgress(content string) ImageProgressEvent {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
return ImageProgressEvent{Step: step, Total: total}
}
func buildResponse(imageBase64, format string) ImageGenerationResponse {
resp := ImageGenerationResponse{
Created: time.Now().Unix(),
Data: make([]ImageData, 1),
}
if imageBase64 == "" {
return resp
}
// The "url" response format is not supported when the image is transferred
// as base64; return b64_json regardless of the requested format.
resp.Data[0].B64JSON = imageBase64
return resp
}
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
// This allows routes.go to delegate image generation with minimal code.
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
opts := api.Options{}
// Schedule runner
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Build completion request
completionReq := llm.CompletionRequest{
Prompt: prompt,
Options: &opts,
}
// Stream responses via channel
ch := make(chan any)
go func() {
defer close(ch)
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
ch <- GenerateResponse{
Model: modelName,
CreatedAt: time.Now().UTC(),
Response: resp.Content,
Done: resp.Done,
}
})
if err != nil {
// The error is intentionally dropped here: the response stream has already
// been sent to the client, so there is no clean way to surface it.
_ = err
}
}()
streamFn(c, ch)
}
// GenerateResponse matches api.GenerateResponse structure for streaming.
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}

View File

@@ -1,31 +0,0 @@
// Package api provides OpenAI-compatible image generation API types.
package api
// ImageGenerationRequest is an OpenAI-compatible image generation request.
type ImageGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// ImageGenerationResponse is an OpenAI-compatible image generation response.
type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageData `json:"data"`
}
// ImageData contains the generated image data.
type ImageData struct {
URL string `json:"url,omitempty"`
B64JSON string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
// ImageProgressEvent is sent during streaming to indicate generation progress.
type ImageProgressEvent struct {
Step int `json:"step"`
Total int `json:"total"`
}
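Putting these (now-removed) types together, a request/response exchange with the handler would have looked roughly like the following; the model name and prompt are illustrative, and omitted fields fall back to the handler defaults (n=1, size "1024x1024", response_format "b64_json"):

POST /v1/images/generations
{"model": "x/z-image-turbo", "prompt": "a lighthouse at dusk"}

{"created": 1737000000, "data": [{"b64_json": "<base64-encoded image>"}]}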

View File

@@ -7,7 +7,6 @@ package imagegen
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
@@ -39,79 +38,20 @@ func DefaultOptions() ImageGenOptions {
return ImageGenOptions{
Width: 1024,
Height: 1024,
Steps: 9,
Steps: 0, // 0 means model default
Seed: 0, // 0 means random
}
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes per parameter) as a rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
// RegisterFlags adds image generation flags to the given command.
// Flags are hidden since they only apply to image generation models.
func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().Int("width", 1024, "Image width")
cmd.Flags().Int("height", 1024, "Image height")
cmd.Flags().Int("steps", 9, "Denoising steps")
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
cmd.Flags().String("negative", "", "Negative prompt")
// Hide from main flags section - shown in separate section via AppendFlagsDocs
cmd.Flags().MarkHidden("width")
cmd.Flags().MarkHidden("height")
cmd.Flags().MarkHidden("steps")
@@ -119,6 +59,19 @@ func RegisterFlags(cmd *cobra.Command) {
cmd.Flags().MarkHidden("negative")
}
// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
func AppendFlagsDocs(cmd *cobra.Command) {
usage := `
Image Generation Flags (experimental):
--width int Image width
--height int Image height
--steps int Denoising steps
--seed int Random seed
--negative str Negative prompt
`
cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
}
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
@@ -158,17 +111,13 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
// Build request with image gen options encoded in Options fields
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
Seed: int64(opts.Seed),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -182,32 +131,25 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
var stepBar *progress.StepBar
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
// Handle progress updates using structured fields
if resp.Total > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
p.StopAndClear()
if err != nil {
return err
}
@@ -245,6 +187,23 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
return err
}
// Preload the model with the specified keepalive
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
preloadReq := &api.GenerateRequest{
Model: modelName,
KeepAlive: keepAlive,
}
if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
return nil
}); err != nil {
p.StopAndClear()
return fmt.Errorf("failed to load model: %w", err)
}
p.StopAndClear()
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
Placeholder: "Describe an image to generate (/help for commands)",
@@ -301,12 +260,10 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
req := &api.GenerateRequest{
Model: modelName,
Prompt: line,
Options: map[string]any{
"num_ctx": opts.Width,
"num_gpu": opts.Height,
"num_predict": opts.Steps,
"seed": opts.Seed,
},
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
Seed: int64(opts.Seed),
}
if keepAlive != nil {
req.KeepAlive = keepAlive
@@ -321,32 +278,25 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
var imageBase64 string
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
content := resp.Response
// Handle progress updates - parse step info and switch to step bar
if strings.HasPrefix(content, "\rGenerating:") {
var step, total int
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
if stepBar == nil && total > 0 {
// Handle progress updates using structured fields
if resp.Total > 0 {
if stepBar == nil {
spinner.Stop()
stepBar = progress.NewStepBar("Generating", total)
stepBar = progress.NewStepBar("Generating", int(resp.Total))
p.Add("", stepBar)
}
if stepBar != nil {
stepBar.Set(step)
}
return nil
stepBar.Set(int(resp.Completed))
}
// Handle final response with base64 image data
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
imageBase64 = content[13:]
// Handle final response with image data
if resp.Done && len(resp.Images) > 0 {
imageBase64 = resp.Images[0]
}
return nil
})
p.Stop()
p.StopAndClear()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue

View File

@@ -1,190 +0,0 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When doQuantize is true, multiple layers are returned (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}
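As the createTensorLayer callback above shows, quantizing fans a single input tensor out into multiple layers. For a hypothetical tensor named transformer/x.weight, the import would record:

transformer/x.weight         quantized weight data
transformer/x.weight_scale   per-group scales
transformer/x.weight_qbias   quantization biases (affine mode only, when present)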

View File

@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
// Lazy initialization of generationStream
if generationStream == nil {
generationStream = mlx.NewStream()
}
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()

View File

@@ -11,11 +11,9 @@ import (
"os"
"path/filepath"
"runtime/pprof"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -50,7 +48,7 @@ func main() {
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
@@ -63,7 +61,6 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
@@ -81,6 +78,11 @@ func main() {
return
}
// Check if MLX initialized successfully
if !mlx.IsMLXAvailable() {
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
@@ -120,33 +122,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *glmImageFlag:
m := &glm_image.Model{}
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
var loadErr error
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
loadErr = m.LoadFromPath(*modelPath)
} else {
loadErr = m.Load(*modelPath)
}
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
Temperature: float32(*temperature),
TopP: float32(*topP),
GuidanceScale: float32(*cfgScale),
MaxVisualTokens: int32(*maxTokens),
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {

View File

@@ -1,19 +0,0 @@
# Image generation models (experimental)
Experimental image generation models are available for **macOS** in Ollama:
## Available models
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
```
ollama run x/z-image-turbo
```
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models.
More models coming soon:
1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

View File

@@ -175,3 +175,63 @@ func (m *ModelManifest) HasTensorLayers() bool {
}
return false
}
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
ParameterCount int64
Quantization string
}
// GetModelInfo returns metadata about an image generation model.
func GetModelInfo(modelName string) (*ModelInfo, error) {
manifest, err := LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
info := &ModelInfo{}
// Read model_index.json for architecture, parameter count, and quantization
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
var index struct {
Architecture string `json:"architecture"`
ParameterCount int64 `json:"parameter_count"`
Quantization string `json:"quantization"`
}
if json.Unmarshal(data, &index) == nil {
info.Architecture = index.Architecture
info.ParameterCount = index.ParameterCount
info.Quantization = index.Quantization
}
}
// Fallback: detect quantization from tensor names if not in config
if info.Quantization == "" {
for _, layer := range manifest.Manifest.Layers {
if strings.HasSuffix(layer.Name, ".weight_scale") {
info.Quantization = "FP8"
break
}
}
if info.Quantization == "" {
info.Quantization = "BF16"
}
}
// Fallback: estimate parameter count if not in config
if info.ParameterCount == 0 {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
totalSize += layer.Size
}
}
}
// Assume BF16 (2 bytes per parameter) as a rough estimate
info.ParameterCount = totalSize / 2
}
return info, nil
}
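For example, under this fallback a model whose non-scale tensor layers total 12 GiB would be reported as roughly 12 * 2^30 / 2 ≈ 6.4 billion parameters.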

View File

@@ -27,7 +27,6 @@ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}
// CheckPlatformSupport validates that image generation is supported on the current platform.
@@ -96,8 +95,3 @@ func EstimateVRAM(modelName string) uint64 {
}
return 21 * GB
}
// HasTensorLayers checks if the given model has tensor layers.
func HasTensorLayers(modelName string) bool {
return ResolveModelName(modelName) != ""
}

View File

@@ -94,13 +94,6 @@ func TestEstimateVRAMDefault(t *testing.T) {
}
}
func TestHasTensorLayers(t *testing.T) {
// Non-existent model should return false
if HasTensorLayers("nonexistent-model") {
t.Error("HasTensorLayers() should return false for non-existent model")
}
}
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")

View File

@@ -3,7 +3,7 @@
package mlx
/*
#include "mlx/c/mlx.h"
#include "mlx.h"
#include <stdlib.h>
// Forward declaration for Go callback

6
x/imagegen/mlx/doc.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build mlx
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
//
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
package mlx
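Since doc.go itself carries the mlx build tag, regenerating the wrappers presumably requires passing that tag so the file is scanned for its go:generate directive:

go generate -tags mlx ./x/imagegen/mlx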

View File

@@ -0,0 +1,439 @@
//go:build ignore
// This tool generates MLX-C dynamic loading wrappers.
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
package main
import (
"bytes"
"flag"
"fmt"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
func findHeaders(directory string) ([]string, error) {
var headers []string
err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
return nil
})
return headers, err
}
func cleanContent(content string) string {
// Remove single-line comments
re := regexp.MustCompile(`//.*?\n`)
content = re.ReplaceAllString(content, "\n")
// Remove multi-line comments
re = regexp.MustCompile(`/\*.*?\*/`)
content = re.ReplaceAllString(content, "")
// Remove preprocessor directives (lines starting with #) - use multiline mode
re = regexp.MustCompile(`(?m)^\s*#.*?$`)
content = re.ReplaceAllString(content, "")
// Remove extern "C" { and } blocks more conservatively
// Only remove the extern "C" { line, not the content inside
re = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
content = re.ReplaceAllString(content, "\n")
// Remove standalone closing braces that are not part of function declarations
re = regexp.MustCompile(`\n\s*\}\s*\n`)
content = re.ReplaceAllString(content, "\n")
// Collapse whitespace and newlines
re = regexp.MustCompile(`\s+`)
content = re.ReplaceAllString(content, " ")
return content
}
func extractParamNames(params string) []string {
if params == "" || strings.TrimSpace(params) == "void" {
return []string{}
}
var names []string
// Split by comma, but respect parentheses (for function pointers)
parts := splitParams(params)
// Remove array brackets
arrayBrackets := regexp.MustCompile(`\[.*?\]`)
// Function pointer pattern
funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)
// Type keywords to skip
typeKeywords := map[string]bool{
"const": true,
"struct": true,
"unsigned": true,
"signed": true,
"long": true,
"short": true,
"int": true,
"char": true,
"float": true,
"double": true,
"void": true,
"size_t": true,
"uint8_t": true,
"uint16_t": true,
"uint32_t": true,
"uint64_t": true,
"int8_t": true,
"int16_t": true,
"int32_t": true,
"int64_t": true,
"intptr_t": true,
"uintptr_t": true,
}
for _, part := range parts {
if part == "" {
continue
}
// Remove array brackets
part = arrayBrackets.ReplaceAllString(part, "")
// For function pointers like "void (*callback)(int)"
if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
names = append(names, matches[1])
continue
}
// Regular parameter: last identifier
tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
if len(tokens) > 0 {
// The last token is usually the parameter name
// Skip type keywords
for i := len(tokens) - 1; i >= 0; i-- {
if !typeKeywords[tokens[i]] {
names = append(names, tokens[i])
break
}
}
}
}
return names
}
func splitParams(params string) []string {
var parts []string
var current bytes.Buffer
depth := 0
for _, char := range params + "," {
switch char {
case '(':
depth++
current.WriteRune(char)
case ')':
depth--
current.WriteRune(char)
case ',':
if depth == 0 {
parts = append(parts, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(char)
}
default:
current.WriteRune(char)
}
}
return parts
}
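For intuition, the two parsing helpers above behave as in this sketch (inputs chosen for illustration):

splitParams("int a, void (*cb)(int, float), size_t n")
// → ["int a", "void (*cb)(int, float)", "size_t n"]

extractParamNames("const mlx_array a, void (*cb)(int), size_t n")
// → ["a", "cb", "n"]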
func parseFunctions(content string) []Function {
var functions []Function
// Match function declarations: return_type function_name(params);
// Matches both mlx_* and _mlx_* functions
pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)
matches := pattern.FindAllStringSubmatch(content, -1)
for _, match := range matches {
returnType := strings.TrimSpace(match[1])
funcName := strings.TrimSpace(match[2])
params := strings.TrimSpace(match[3])
// Skip if this looks like a variable declaration
if params == "" || strings.Contains(params, "{") {
continue
}
// Clean up return type
returnType = strings.Join(strings.Fields(returnType), " ")
// Extract parameter names
paramNames := extractParamNames(params)
// Check if ARM64 guard is needed
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}
return functions
}
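For example, a declaration of the form "mlx_array mlx_add(mlx_array a, mlx_array b);" would be parsed into Name "mlx_add", ReturnType "mlx_array", and Params "mlx_array a, mlx_array b" (the signature shown is illustrative, not the actual mlx_add prototype).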
func needsARM64Guard(name, retType, params string) bool {
return strings.Contains(name, "float16") ||
strings.Contains(name, "bfloat16") ||
strings.Contains(retType, "float16_t") ||
strings.Contains(retType, "bfloat16_t") ||
strings.Contains(params, "float16_t") ||
strings.Contains(params, "bfloat16_t")
}
func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
// Generate header file
var headerBuf bytes.Buffer
headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
headerBuf.WriteString("//\n")
headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")
headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
headerBuf.WriteString("#include <stdio.h>\n\n")
// Undef all MLX functions to avoid conflicts
headerBuf.WriteString("// Undefine any existing MLX function macros\n")
for _, fn := range functions {
headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
}
headerBuf.WriteString("\n")
// Function pointer extern declarations
headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
headerBuf.WriteString("#endif\n")
}
}
headerBuf.WriteString("\n")
// Initialization function declaration
headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")
// Wrapper function declarations
headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
headerBuf.WriteString("#endif\n")
}
headerBuf.WriteString("\n")
}
headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")
// Write header file
if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
return fmt.Errorf("failed to write header file: %w", err)
}
// Generate implementation file
var implBuf bytes.Buffer
implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
implBuf.WriteString("#include <stdio.h>\n")
implBuf.WriteString("#include <dlfcn.h>\n\n")
// Function pointer definitions
implBuf.WriteString("// Function pointer definitions\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
}
implBuf.WriteString("\n")
// Initialization function
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
implBuf.WriteString(" if (handle == NULL) {\n")
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
implBuf.WriteString(" return -1;\n")
implBuf.WriteString(" }\n\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
implBuf.WriteString(" return -1;\n")
implBuf.WriteString(" }\n")
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
}
implBuf.WriteString(" return 0;\n")
implBuf.WriteString("}\n\n")
// Wrapper function implementations
implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
for _, fn := range functions {
if fn.NeedsARM64Guard {
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
}
implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))
// Call through function pointer
if fn.ReturnType != "void" {
implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
} else {
implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
}
// Pass parameters
implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
implBuf.WriteString(");\n")
implBuf.WriteString("}\n")
if fn.NeedsARM64Guard {
implBuf.WriteString("#endif\n")
}
implBuf.WriteString("\n")
}
// Write implementation file
if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
return fmt.Errorf("failed to write implementation file: %w", err)
}
return nil
}
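// For reference, the generated output for a single function looks roughly like
// the following, assuming the MLX-C signature
// int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s):
//
//	header: extern int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
//	        int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
//	impl:   int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
//	            return mlx_add_ptr(res, a, b, s);
//	        }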
func main() {
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
flag.PrintDefaults()
}
flag.Parse()
args := flag.Args()
if len(args) < 2 {
fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
flag.Usage()
os.Exit(1)
}
headerDir := args[0]
outputHeader := args[1]
// Default implementation file is the header path with the extension changed to .c
outputImpl := outputHeader
if len(args) > 2 {
outputImpl = args[2]
} else if strings.HasSuffix(outputHeader, ".h") {
outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
}
// Check if header directory exists
if _, err := os.Stat(headerDir); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)
// Find all headers
headers, err := findHeaders(headerDir)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))
// Parse all headers
var allFunctions []Function
seen := make(map[string]bool)
for _, header := range headers {
content, err := os.ReadFile(header)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
continue
}
cleaned := cleanContent(string(content))
functions := parseFunctions(cleaned)
// Deduplicate
for _, fn := range functions {
if !seen[fn.Name] {
seen[fn.Name] = true
allFunctions = append(allFunctions, fn)
}
}
}
fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))
// Generate wrapper files
if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
}
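// A minimal sketch of invoking the generator from the repo root (paths are
// illustrative and depend on where CMake placed the MLX-C sources):
//
//	go run generate_wrappers.go build/_deps/mlx-c-src/mlx/c x/imagegen/mlx/mlx.h x/imagegen/mlx/mlx.c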

x/imagegen/mlx/mlx.c (new file, 5786 lines)
File diff suppressed because it is too large

@@ -3,12 +3,13 @@
package mlx
/*
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
#cgo linux LDFLAGS: -lstdc++ -ldl
#cgo windows LDFLAGS: -lstdc++
#include "mlx/c/mlx.h"
// Use generated wrappers instead of direct MLX headers
#include "mlx.h"
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
@@ -42,192 +43,6 @@ static inline mlx_stream cpu_stream() {
// CGO noescape/nocallback hints to reduce CGO overhead
// noescape: pointers won't escape, no heap allocation needed
// nocallback: function won't call back into Go
#cgo noescape mlx_add
#cgo nocallback mlx_add
#cgo noescape mlx_subtract
#cgo nocallback mlx_subtract
#cgo noescape mlx_multiply
#cgo nocallback mlx_multiply
#cgo noescape mlx_divide
#cgo nocallback mlx_divide
#cgo noescape mlx_negative
#cgo nocallback mlx_negative
#cgo noescape mlx_abs
#cgo nocallback mlx_abs
#cgo noescape mlx_exp
#cgo nocallback mlx_exp
#cgo noescape mlx_log
#cgo nocallback mlx_log
#cgo noescape mlx_sqrt
#cgo nocallback mlx_sqrt
#cgo noescape mlx_rsqrt
#cgo nocallback mlx_rsqrt
#cgo noescape mlx_square
#cgo nocallback mlx_square
#cgo noescape mlx_power
#cgo nocallback mlx_power
#cgo noescape mlx_erf
#cgo nocallback mlx_erf
#cgo noescape mlx_sigmoid
#cgo nocallback mlx_sigmoid
#cgo noescape mlx_tanh
#cgo nocallback mlx_tanh
#cgo noescape mlx_sin
#cgo nocallback mlx_sin
#cgo noescape mlx_cos
#cgo nocallback mlx_cos
#cgo noescape mlx_maximum
#cgo nocallback mlx_maximum
#cgo noescape mlx_minimum
#cgo nocallback mlx_minimum
#cgo noescape mlx_clip
#cgo nocallback mlx_clip
#cgo noescape mlx_sum
#cgo nocallback mlx_sum
#cgo noescape mlx_sum_axis
#cgo nocallback mlx_sum_axis
#cgo noescape mlx_mean
#cgo nocallback mlx_mean
#cgo noescape mlx_mean_axis
#cgo nocallback mlx_mean_axis
#cgo noescape mlx_var_axis
#cgo nocallback mlx_var_axis
#cgo noescape mlx_argmax
#cgo nocallback mlx_argmax
#cgo noescape mlx_argmax_axis
#cgo nocallback mlx_argmax_axis
#cgo noescape mlx_softmax_axis
#cgo nocallback mlx_softmax_axis
#cgo noescape mlx_cumsum
#cgo nocallback mlx_cumsum
#cgo noescape mlx_matmul
#cgo nocallback mlx_matmul
#cgo noescape mlx_addmm
#cgo nocallback mlx_addmm
#cgo noescape mlx_gather_mm
#cgo nocallback mlx_gather_mm
#cgo noescape mlx_gather_qmm
#cgo nocallback mlx_gather_qmm
#cgo noescape mlx_reshape
#cgo nocallback mlx_reshape
#cgo noescape mlx_transpose_axes
#cgo nocallback mlx_transpose_axes
#cgo noescape mlx_expand_dims
#cgo nocallback mlx_expand_dims
#cgo noescape mlx_squeeze_axis
#cgo nocallback mlx_squeeze_axis
#cgo noescape mlx_flatten
#cgo nocallback mlx_flatten
#cgo noescape mlx_concatenate_axis
#cgo nocallback mlx_concatenate_axis
#cgo noescape mlx_slice
#cgo nocallback mlx_slice
#cgo noescape mlx_slice_update
#cgo nocallback mlx_slice_update
#cgo noescape mlx_as_strided
#cgo nocallback mlx_as_strided
#cgo noescape mlx_view
#cgo nocallback mlx_view
#cgo noescape mlx_contiguous
#cgo nocallback mlx_contiguous
#cgo noescape mlx_pad
#cgo nocallback mlx_pad
#cgo noescape mlx_tile
#cgo nocallback mlx_tile
#cgo noescape mlx_take_axis
#cgo nocallback mlx_take_axis
#cgo noescape mlx_take_along_axis
#cgo nocallback mlx_take_along_axis
#cgo noescape mlx_put_along_axis
#cgo nocallback mlx_put_along_axis
#cgo noescape mlx_where
#cgo nocallback mlx_where
#cgo noescape mlx_argsort_axis
#cgo nocallback mlx_argsort_axis
#cgo noescape mlx_argpartition_axis
#cgo nocallback mlx_argpartition_axis
#cgo noescape mlx_topk_axis
#cgo nocallback mlx_topk_axis
#cgo noescape mlx_less
#cgo nocallback mlx_less
#cgo noescape mlx_greater_equal
#cgo nocallback mlx_greater_equal
#cgo noescape mlx_logical_and
#cgo nocallback mlx_logical_and
#cgo noescape mlx_zeros
#cgo nocallback mlx_zeros
#cgo noescape mlx_zeros_like
#cgo nocallback mlx_zeros_like
#cgo noescape mlx_ones
#cgo nocallback mlx_ones
#cgo noescape mlx_full
#cgo nocallback mlx_full
#cgo noescape mlx_arange
#cgo nocallback mlx_arange
#cgo noescape mlx_linspace
#cgo nocallback mlx_linspace
#cgo noescape mlx_tri
#cgo nocallback mlx_tri
#cgo noescape mlx_astype
#cgo nocallback mlx_astype
#cgo noescape mlx_fast_rms_norm
#cgo nocallback mlx_fast_rms_norm
#cgo noescape mlx_fast_rope
#cgo nocallback mlx_fast_rope
#cgo noescape mlx_fast_scaled_dot_product_attention
#cgo nocallback mlx_fast_scaled_dot_product_attention
#cgo noescape mlx_conv2d
#cgo nocallback mlx_conv2d
#cgo noescape mlx_conv3d
#cgo nocallback mlx_conv3d
#cgo noescape mlx_random_key
#cgo nocallback mlx_random_key
#cgo noescape mlx_random_split
#cgo nocallback mlx_random_split
#cgo noescape mlx_random_categorical_num_samples
#cgo nocallback mlx_random_categorical_num_samples
#cgo noescape mlx_random_normal
#cgo nocallback mlx_random_normal
#cgo noescape mlx_random_uniform
#cgo nocallback mlx_random_uniform
#cgo noescape mlx_array_eval
#cgo nocallback mlx_array_eval
#cgo noescape mlx_eval
#cgo nocallback mlx_eval
#cgo noescape mlx_async_eval
#cgo nocallback mlx_async_eval
#cgo noescape mlx_synchronize
#cgo nocallback mlx_synchronize
#cgo noescape mlx_array_new
#cgo nocallback mlx_array_new
#cgo noescape mlx_array_new_data
#cgo nocallback mlx_array_new_data
#cgo noescape mlx_array_new_float
#cgo nocallback mlx_array_new_float
#cgo noescape mlx_array_free
#cgo nocallback mlx_array_free
#cgo noescape mlx_array_size
#cgo nocallback mlx_array_size
#cgo noescape mlx_array_ndim
#cgo nocallback mlx_array_ndim
#cgo noescape mlx_array_dim
#cgo nocallback mlx_array_dim
#cgo noescape mlx_array_dtype
#cgo nocallback mlx_array_dtype
#cgo noescape mlx_array_item_int32
#cgo nocallback mlx_array_item_int32
#cgo noescape mlx_vector_array_new_data
#cgo nocallback mlx_vector_array_new_data
#cgo noescape mlx_vector_array_free
#cgo nocallback mlx_vector_array_free
#cgo noescape mlx_array_new_int
#cgo nocallback mlx_array_new_int
#cgo noescape mlx_stream_new_device
#cgo nocallback mlx_stream_new_device
#cgo noescape mlx_get_default_stream
#cgo nocallback mlx_get_default_stream
#cgo noescape mlx_set_default_stream
#cgo nocallback mlx_set_default_stream
*/
import "C"
import (
@@ -1796,7 +1611,57 @@ func ArgmaxKeepArray(logits *Array) *Array {
var RandomState = []*Array{nil}
var randomStateMu sync.Mutex
var mlxInitialized bool
var mlxInitError error
// InitMLX initializes the MLX library by dynamically loading libmlxc.
// This must be called before using any MLX functions.
// Returns an error if the library cannot be loaded.
func InitMLX() error {
if mlxInitialized {
return mlxInitError
}
// Try to load the MLX dynamic library
ret := C.mlx_dynamic_init()
if ret != 0 {
errMsg := C.GoString(C.mlx_dynamic_error())
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
return mlxInitError
}
// Initialize all function pointers via dlsym
handle := C.mlx_get_handle()
ret = C.mlx_load_functions(handle)
if ret != 0 {
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
return mlxInitError
}
mlxInitialized = true
mlxInitError = nil
return nil
}
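// Illustrative caller-side usage: gate image generation on successful init.
//
//	if err := InitMLX(); err != nil {
//	    return fmt.Errorf("image generation unavailable: %w", err)
//	}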
// IsMLXAvailable returns whether MLX was successfully initialized
func IsMLXAvailable() bool {
return mlxInitialized && mlxInitError == nil
}
// GetMLXInitError returns any error that occurred during MLX initialization
func GetMLXInitError() error {
return mlxInitError
}
func init() {
// Initialize MLX dynamic library first
if err := InitMLX(); err != nil {
// Don't panic in init - let the caller handle the error
// Store the error for later retrieval
mlxInitError = err
return
}
// Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread()

x/imagegen/mlx/mlx.h (new file, 2337 lines)
File diff suppressed because it is too large

@@ -0,0 +1,144 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
#ifdef __APPLE__
// Get path to the library in the same directory as the executable
// (returns a pointer to a static buffer; not thread-safe)
static char* get_exe_relative_path(const char* libname) {
static char path[1024];
uint32_t size = sizeof(path);
if (_NSGetExecutablePath(path, &size) != 0) {
return NULL;
}
// Get directory of executable
char* dir = dirname(path);
static char fullpath[1024];
snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
return fullpath;
}
#endif
// Try to load library from a specific path
static int try_load_lib(const char* path) {
if (!path) return 0;
mlx_handle = LOAD_LIB(path);
return mlx_handle != NULL;
}
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
const char* lib_path = NULL;
const char* tried_paths[8] = {0};
int num_tried = 0;
#ifdef _WIN32
// Windows: rely on the default DLL search path, which includes the executable's directory
lib_path = "libmlxc.dll";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#elif defined(__APPLE__)
// macOS: try executable directory first
lib_path = get_exe_relative_path("libmlxc.dylib");
if (lib_path) {
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
}
// Try build directory (for tests run from repo root)
lib_path = "./build/lib/ollama/libmlxc.dylib";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
// Fallback to system paths
lib_path = "libmlxc.dylib";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#else
// Linux: try build directory first (for tests)
lib_path = "./build/lib/ollama/libmlxc.so";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
// Fallback to system paths
lib_path = "libmlxc.so";
tried_paths[num_tried++] = lib_path;
if (try_load_lib(lib_path)) goto success;
#endif
// Failed to load library - build error message with all tried paths
{
const char* err = LIB_ERROR();
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. Tried: ");
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
}
if (err) {
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
". Last error: %s", err);
}
}
return -1;
success:
mlx_initialized = 1;
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
return 0;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void) {
return mlx_handle;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}


@@ -0,0 +1,29 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Get the library handle (for use by generated wrappers)
void* mlx_get_handle(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H


@@ -4,9 +4,30 @@ package mlx
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := InitMLX(); err != nil {
fmt.Printf("Skipping MLX tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
func TestBasicCleanup(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})


@@ -1,693 +0,0 @@
//go:build mlx
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image
import (
"context"
"fmt"
"math"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
PadTokenID int32
EOSTokenID int32
UNKTokenID int32
}
// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
return &ByT5Tokenizer{
PadTokenID: 0,
EOSTokenID: 1,
UNKTokenID: 2,
}
}
// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
bytes := []byte(text)
tokens := make([]int32, len(bytes))
for i, b := range bytes {
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
// (tokens 0, 1, 2 are PAD, EOS, UNK)
tokens[i] = int32(b) + 3
}
return tokens
}
// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
bytes := make([]byte, 0, len(tokens))
for _, tok := range tokens {
if tok >= 3 && tok < 259 {
bytes = append(bytes, byte(tok-3))
}
}
return string(bytes)
}
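// Example (illustrative): Encode("Hi") returns [75, 108], since 'H' (byte 72)
// maps to 72+3 = 75 and 'i' (byte 105) maps to 105+3 = 108; Decode inverts this.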
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
GuidanceScale float32 // Guidance scale (default: 1.5)
Width int32 // Image width (default: 1024, must be divisible by 32)
Height int32 // Image height (default: 1024, must be divisible by 32)
Steps int // Diffusion denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// AR generation options
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
Temperature float32 // AR sampling temperature (default: 0.9)
TopP float32 // Nucleus sampling (default: 0.75)
}
// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
// Model represents a GLM-Image hybrid model.
type Model struct {
ModelName string
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
TextEncoder *T5TextEncoder
VisionLanguageEncoder *VisionLanguageEncoder
Transformer *DiffusionTransformer
VAEDecoder *VAEDecoder
}
// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
// Used for T5 text encoder (glyph embeddings)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizer(manifest)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder (~830MB)
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder (~19GB, 9B params)
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer (~13GB, 7B params)
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder (~775MB)
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelPath
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizerFromPath(modelPath)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.5
}
// Calculate MaxVisualTokens based on image dimensions
// GLM-Image generates TWO grids of visual tokens:
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
// 2. Then: target (large) grid - tokenH × tokenW tokens
// After generation, we extract only the TARGET grid tokens for diffusion.
factor := int32(32)
tokenH := cfg.Height / factor
tokenW := cfg.Width / factor
targetGridTokens := tokenH * tokenW
// Compute prev grid dimensions using diffusers formula:
// ratio = token_h / token_w
// prev_token_h = int(sqrt(ratio) * 16)
// prev_token_w = int(sqrt(1/ratio) * 16)
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
prevGridTokens := prevTokenH * prevTokenW
// Total tokens to generate = prev grid + target grid
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
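// Example (illustrative): a 1024x1024 request gives tokenH = tokenW = 32 and
// targetGridTokens = 1024; ratio = 1, so prevTokenH = prevTokenW = 16 and
// prevGridTokens = 256, giving MaxVisualTokens = 256 + 1024 = 1280.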
if cfg.Temperature <= 0 {
cfg.Temperature = 0.9
}
if cfg.TopP <= 0 {
cfg.TopP = 0.75
}
// Ensure dimensions are divisible by 32
cfg.Width = (cfg.Width / 32) * 32
cfg.Height = (cfg.Height / 32) * 32
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
// Progress callback helper
progress := func(stage string, step, total int) {
if cfg.Progress != nil {
cfg.Progress(stage, step, total)
}
}
// === PHASE 1: T5 Text Encoding ===
fmt.Println("[T5] Encoding glyph text...")
progress("text_encoding", 0, 1)
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
mlx.Keep(textEmbed)
mlx.Eval(textEmbed)
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
progress("text_encoding", 1, 1)
// === PHASE 2: AR Visual Token Generation ===
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
visualTokens := m.VisionLanguageEncoder.Generate(
cfg.Prompt,
m.GLMTokenizer,
cfg.MaxVisualTokens,
cfg.Temperature,
cfg.TopP,
cfg.Seed,
cfg.Height,
cfg.Width,
func(step int) {
if step%100 == 0 || step < 10 {
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
}
progress("ar_generation", step, int(cfg.MaxVisualTokens))
},
)
mlx.Keep(visualTokens)
mlx.Eval(visualTokens)
fmt.Printf("[AR] Done generating visual tokens\n")
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
vtShape := visualTokens.Shape()
totalGenerated := vtShape[1]
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
// Extract only the TARGET grid tokens (skip the prev grid tokens)
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
// large_image_start_offset = prev_grid_size
var targetGridVisualTokens *mlx.Array
if totalGenerated >= prevGridTokens+targetGridTokens {
// Full generation completed - extract target grid
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, prevGridTokens + targetGridTokens})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
} else if totalGenerated > prevGridTokens {
// Partial target grid - take what we have
actualTargetTokens := totalGenerated - prevGridTokens
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, totalGenerated})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
actualTargetTokens, targetGridTokens)
} else {
// Not enough tokens - EOS came too early
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
totalGenerated, prevGridTokens)
}
// === PHASE 3: Diffusion Decoding ===
// Setup scheduler with dynamic shift based on image size
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
// Initialize noise latents [B, C, H, W]
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
// target_grid tokens -> 2x upsample -> patch_count
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
mlx.Keep(priorEmbed)
mlx.Eval(priorEmbed)
// Prepare text conditioning (project T5 embeddings)
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
mlx.Keep(textCond)
mlx.Eval(textCond)
// === CFG Setup ===
// For classifier-free guidance, we need unconditional (negative) text embeddings
// GLM-Image uses empty string "" for negative prompt
doCFG := cfg.GuidanceScale > 1.0
var negativeTextCond *mlx.Array
if doCFG {
// Encode empty string for negative prompt
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
mlx.Keep(negativeTextEmbed)
mlx.Eval(negativeTextEmbed)
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
mlx.Keep(negativeTextCond)
mlx.Eval(negativeTextCond)
negativeTextEmbed.Free()
}
// Prepare conditioning inputs
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
targetSize = mlx.ToBFloat16(targetSize)
cropCoords = mlx.ToBFloat16(cropCoords)
mlx.Keep(targetSize)
mlx.Keep(cropCoords)
mlx.Eval(targetSize, cropCoords)
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
// Denoising loop
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
progress("diffusion", 0, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled is a separate array and was never Keep'd, so it is reclaimed automatically
priorEmbed.Free()
textCond.Free()
latents.Free()
return nil, ctx.Err()
default:
}
}
// Get timestep value for the transformer
// scheduler.Timesteps contains raw timestep values (1000 down to 1)
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
timestepVal := scheduler.Timesteps[i] - 1
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
patches := PatchifyLatents(latents, tcfg.PatchSize)
// Transformer forward with MMDiT architecture
// Conditional pass (with text + prior embeddings)
outputCond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed,
textCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
false, // priorTokenDrop = false for conditional
)
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
var noisePred *mlx.Array
if doCFG {
// Unconditional pass (empty text, dropped prior embeddings)
outputUncond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
negativeTextCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
true, // priorTokenDrop = true for unconditional
)
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
diff := mlx.Sub(noisePredCond, noisePredUncond)
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
noisePred = mlx.Add(noisePredUncond, scaled)
} else {
noisePred = noisePredCond
}
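// With the default GuidanceScale of 1.5 the combination above is equivalent
// to 1.5*cond - 0.5*uncond.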
// Scheduler step
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
progress("diffusion", i+1, cfg.Steps)
}
// Cleanup intermediate arrays
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled is a separate array and was never Keep'd, so it is reclaimed automatically
priorEmbed.Free()
textCond.Free()
if negativeTextCond != nil {
negativeTextCond.Free()
}
targetSize.Free()
cropCoords.Free()
// === PHASE 4: VAE Decode ===
progress("vae_decode", 0, 1)
decoded := m.VAEDecoder.Decode(latents)
mlx.Eval(decoded)
latents.Free()
progress("vae_decode", 1, 1)
return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
// Each token at (i, j) becomes scale*scale tokens in the output
mlx.Eval(tokens)
tokenData := tokens.DataInt32()
numTokens := int32(len(tokenData))
expectedTokens := prevH * prevW
// Warn if we got fewer tokens than expected (early EOS)
if numTokens < expectedTokens {
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
numTokens, expectedTokens)
}
targetH := prevH * scale
targetW := prevW * scale
upsampled := make([]int32, targetH*targetW)
for i := int32(0); i < prevH; i++ {
for j := int32(0); j < prevW; j++ {
srcIdx := i*prevW + j
// Handle early EOS: use 0 (padding) for missing tokens
var val int32
if srcIdx < numTokens {
val = tokenData[srcIdx]
} else {
val = 0 // Padding token
}
// Place in scale*scale positions
dstI := i * scale
dstJ := j * scale
for di := int32(0); di < scale; di++ {
for dj := int32(0); dj < scale; dj++ {
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
}
}
}
}
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
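// Example (illustrative): a 2x2 prev grid [a b; c d] with scale=2 produces the
// 4x4 grid [a a b b; a a b b; c c d d; c c d d], flattened row-major.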
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// Transpose: -> [B, pH, pW, C, p, p]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// Flatten: -> [B, pH*pW, C*p*p]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
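// Example (illustrative): with 128x128 latents, 16 channels, and patchSize=2,
// pH = pW = 64, so the output shape is [B, 64*64, 16*2*2] = [B, 4096, 64].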
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// Transpose: -> [B, C, pH, p, pW, p]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// Reshape: -> [B, C, H, W]
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
cfg := DefaultSchedulerConfig()
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
return m*cfg.MaxShift + cfg.BaseShift
}
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
shape := tokens.Shape()
B := shape[0]
// Reshape to [B, 1, H, W] for interpolation
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
// Convert to float for interpolation
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
// 2x nearest neighbor upsample
// [B, 1, H, W] -> [B, 1, H*2, W*2]
upsampled := nearestUpsample2x(tokensFloat)
// Convert back to int and reshape to [B, H*2*W*2]
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Repeat each element 2x2
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Tile to repeat each pixel 2x2
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
// Reshape to final size
return mlx.Reshape(x, B, C, H*2, W*2)
}


@@ -1,358 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen"
)
// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
Vocab map[string]int32 // token string -> token ID
VocabReverse map[int32]string // token ID -> token string
SpecialTokens map[string]int32 // special token strings -> IDs
// Special token IDs
SopTokenID int32 // <sop> = grid_bos_token (167845)
EopTokenID int32 // <eop> = grid_eos_token (167846)
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
PadTokenID int32
// Sorted vocab keys by length (longest first) for greedy matching
sortedTokens []string
}
// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
Model struct {
Vocab map[string]int32 `json:"vocab"`
} `json:"model"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory in manifest
data, err := manifest.ReadConfig("processor/tokenizer.json")
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
data, err := os.ReadFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
if text == "" {
return []int32{}
}
var tokens []int32
// First, collect the special tokens that occur in the text so they can be
// matched inline during the scan below
specialReplacements := make(map[string]int32)
for special, id := range t.SpecialTokens {
if strings.Contains(text, special) {
specialReplacements[special] = id
}
}
// Process text character by character with special token handling
i := 0
isFirstToken := true
for i < len(text) {
// Check for special tokens first
foundSpecial := false
for special, id := range specialReplacements {
if strings.HasPrefix(text[i:], special) {
tokens = append(tokens, id)
i += len(special)
isFirstToken = false
foundSpecial = true
break
}
}
if foundSpecial {
continue
}
// Handle regular text with GPT-2 style space prefix
// "Ġ" (U+0120) represents a space before a token
remaining := text[i:]
// Try to find the longest matching token
matched := false
for _, token := range t.sortedTokens {
// Skip special tokens in regular matching
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
continue
}
// Check if this token matches
tokenText := token
// Handle the Ġ prefix (represents space)
if strings.HasPrefix(token, "Ġ") {
// This token expects a leading space
if i > 0 || !isFirstToken {
// Check if remaining starts with space + token content
tokenContent := token[len("Ġ"):]
if strings.HasPrefix(remaining, " "+tokenContent) {
tokens = append(tokens, t.Vocab[token])
i += 1 + len(tokenContent) // space + content
isFirstToken = false
matched = true
break
}
}
} else {
// Regular token without space prefix
if strings.HasPrefix(remaining, tokenText) {
tokens = append(tokens, t.Vocab[token])
i += len(tokenText)
isFirstToken = false
matched = true
break
}
}
}
if !matched {
// No token found - skip this character (or use UNK)
// For now, just skip unknown characters
i++
}
}
return tokens
}
// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
// Calculate grid dimensions
factor := int32(32)
height := (targetHeight / factor) * factor
width := (targetWidth / factor) * factor
tokenH := height / factor
tokenW := width / factor
// Calculate previous grid dimensions
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(sqrt(ratio) * 16)
prevTokenW := int32(sqrt(1.0/ratio) * 16)
// Encode the prompt text
promptTokens := t.Encode(prompt)
// Build the full sequence:
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
var tokens []int32
tokens = append(tokens, promptTokens...)
// First grid: <sop> H W <eop>
// First number has no space prefix, second number has space prefix (Ġ)
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(tokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
tokens = append(tokens, t.EopTokenID)
// Second grid: <sop> prevH prevW <eop>
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
tokens = append(tokens, t.EopTokenID)
// BOS token (start of image generation)
tokens = append(tokens, t.BosTokenID)
return tokens
}
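// Example (illustrative): for a 1024x1024 target, tokenH = tokenW = 32 and
// prevTokenH = prevTokenW = 16, so the suffix appended after the prompt is
// <sop> "32" "Ġ32" <eop> <sop> "16" "Ġ16" <eop> <|dit_token_16384|>.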
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up the whole number as a single token
if id, ok := t.Vocab[s]; ok {
return []int32{id}
}
// Fallback: encode digit by digit
var tokens []int32
for _, c := range s {
if id, ok := t.Vocab[string(c)]; ok {
tokens = append(tokens, id)
}
}
return tokens
}
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
spaceToken := "Ġ" + s
if id, ok := t.Vocab[spaceToken]; ok {
return []int32{id}
}
// Fallback: bare space Ġ + number tokens
var tokens []int32
if spaceID, ok := t.Vocab["Ġ"]; ok {
tokens = append(tokens, spaceID)
}
tokens = append(tokens, t.encodeNumber(n)...)
return tokens
}
// sqrt computes the square root of x using Newton's method
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
// Newton's method
z := x
for i := 0; i < 10; i++ {
z = z - (z*z-x)/(2*z)
}
return z
}
// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
var sb strings.Builder
for _, id := range tokens {
if token, ok := t.VocabReverse[id]; ok {
// Handle Ġ prefix (convert back to space)
if strings.HasPrefix(token, "Ġ") {
sb.WriteString(" ")
sb.WriteString(token[len("Ġ"):])
} else {
sb.WriteString(token)
}
}
}
return sb.String()
}


@@ -1,159 +0,0 @@
//go:build mlx
package glm_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.25
MaxShift float32 `json:"max_shift"` // 0.75
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "linear"
}
// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.25,
MaxShift: 0.75,
BaseImageSeqLen: 256,
MaxImageSeqLen: 4096,
UseDynamicShifting: true,
TimeShiftType: "linear",
}
}
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
Sigmas []float32 // Shifted sigmas for Euler step calculation
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{Config: cfg}
}
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
s.NumSteps = numSteps
// Calculate shift (mu) based on image sequence length
mu := s.calculateShift(imgSeqLen)
// Create timesteps: linspace from sigma_max_t to sigma_min_t
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
// Then apply time shift and append terminal sigma=0
s.Timesteps = make([]float32, numSteps)
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
for i := 0; i < numSteps; i++ {
// linspace from num_train_timesteps (1000) down to 1 (sigma_min * num_train_timesteps)
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
s.Timesteps[i] = tRaw
// Convert to sigma [0, 1]
sigma := tRaw / numTrainTimesteps
// Apply time shift if enabled
if s.Config.UseDynamicShifting && mu > 0 {
sigma = s.applyShift(mu, sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal sigma = 0 (the final clean image)
s.Sigmas[numSteps] = 0
}
// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
// m = (image_seq_len / base_seq_len) ** 0.5
// mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
cfg := s.Config
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
mu := m*cfg.MaxShift + cfg.BaseShift
return mu
}
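// Example (illustrative): for a 1024x1024 image with patch size 2,
// imgSeqLen = (128/2)*(128/2) = 4096, so m = sqrt(4096/256) = 4 and
// mu = 4*0.75 + 0.25 = 3.25.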
// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if t >= 1 {
return 1
}
// exponent sigma = 1.0 for both shift types (distinct from the noise-level sigma)
sigma := float32(1.0)
if s.Config.TimeShiftType == "linear" {
// Linear: mu / (mu + (1/t - 1)^sigma)
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
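// Example: with mu = 3.25 and t = 0.5, (1/t - 1) = 1, so the linear shift gives
// 3.25/(3.25+1) ≈ 0.765; mid-range sigmas are pushed toward the noisy end.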
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
sigma := s.Sigmas[stepIdx]
sigmaNext := s.Sigmas[stepIdx+1]
// Euler step: x_{t-dt} = x_t + dt * v_t
dt := sigmaNext - sigma // Negative (going from noise to clean)
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
// Use sigmas (shifted) for the interpolation
sigma := s.Sigmas[timestepIdx]
oneMinusSigma := 1.0 - sigma
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
scaledNoise := mlx.MulScalar(noise, sigma)
return mlx.Add(scaledClean, scaledNoise)
}
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
return s.Timesteps
}
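// Illustrative usage (not part of the diff): a minimal denoising loop built
// from the scheduler above. "predict" stands in for the diffusion
// transformer's velocity prediction and is a hypothetical callback; the
// latent shape [1, 16, 64, 64] and step/seed values are assumed examples.
func exampleDenoise(predict func(x *mlx.Array, t float32) *mlx.Array) *mlx.Array {
s := NewFlowMatchScheduler(DefaultSchedulerConfig())
s.SetTimestepsWithDynamicShift(50, 4096) // 50 steps, image seq len 4096
x := s.InitNoise([]int32{1, 16, 64, 64}, 42)
for i, t := range s.GetTimesteps() {
v := predict(x, t) // condition on the raw (unshifted) timestep
x = s.Step(v, x, i) // Euler update using the shifted sigmas
mlx.Eval(x)
}
return x
}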

View File

@@ -1,497 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// T5Config holds T5 encoder configuration
type T5Config struct {
DModel int32 `json:"d_model"` // 1472
DFF int32 `json:"d_ff"` // 3584
DKV int32 `json:"d_kv"` // 64
NumHeads int32 `json:"num_heads"` // 6
NumLayers int32 `json:"num_layers"` // 12
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
// Relative position bias
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}
// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
Config *T5Config
// Embedding (shared for ByT5)
SharedEmbed *nn.Embedding `weight:"shared"`
// Encoder layers
Layers []*T5Block `weight:"encoder.block"`
// Final layer norm
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
// Relative position bias (from first layer, shared across all)
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}
// T5Block is a single T5 encoder block
type T5Block struct {
// Self attention
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
// FFN
Layer1 *T5LayerFF `weight:"layer.1"`
}
// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
SelfAttention *T5Attention `weight:"SelfAttention"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5Attention implements T5's relative attention
type T5Attention struct {
Q *mlx.Array `weight:"q.weight"` // No bias in T5
K *mlx.Array `weight:"k.weight"`
V *mlx.Array `weight:"v.weight"`
O *mlx.Array `weight:"o.weight"`
NHeads int32
DKV int32
Scale float32
}
// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
Wo *mlx.Array `weight:"wo.weight"` // down projection
}
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32
}
// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
func (m *T5TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.LayerNormEps
for _, block := range m.Layers {
attn := block.Layer0.SelfAttention
attn.NHeads = cfg.NumHeads
attn.DKV = cfg.DKV
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
}
}
// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
cfg := m.Config
// Get embeddings
h := m.SharedEmbed.Forward(tokens)
// Compute relative position bias once
seqLen := tokens.Shape()[1]
posBias := m.computeRelativePositionBias(seqLen)
// Forward through layers
for _, block := range m.Layers {
h = block.Forward(h, posBias, cfg.LayerNormEps)
}
// Final norm
h = m.FinalNorm.Forward(h)
return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
func extractGlyphTexts(prompt string) []string {
var glyphTexts []string
// Extract text in single quotes: 'text'
re1 := regexp.MustCompile(`'([^']*)'`)
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Unicode curly double quotes: “text”
re2 := regexp.MustCompile(`“([^“”]*)”`)
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in ASCII double quotes: "text"
re3 := regexp.MustCompile(`"([^"]*)"`)
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Japanese quotes: 「text」
re4 := regexp.MustCompile(`「([^「」]*)」`)
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
return glyphTexts
}
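// Illustrative example of the extraction above (assumed prompt):
//
// extractGlyphTexts(`a poster that says 'GRAND OPENING' and "50% off"`)
//
// returns []string{"GRAND OPENING", "50% off"}: the single-quoted match from
// re1 first, then the ASCII double-quoted match from re3.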
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
// Extract glyph texts from prompt (text in quotes)
glyphTexts := extractGlyphTexts(prompt)
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
if len(glyphTexts) == 0 {
glyphTexts = []string{""}
}
// Encode each glyph text and collect token sequences
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
var allTokenSeqs [][]int32
for _, glyphText := range glyphTexts {
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
tokens := tok.Encode(glyphText)
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
tokens = append(tokens, tok.EOSTokenID)
allTokenSeqs = append(allTokenSeqs, tokens)
}
// Process each glyph text through the encoder
var allEmbeddings []*mlx.Array
for _, tokens := range allTokenSeqs {
tokenLen := len(tokens)
if tokenLen == 0 {
continue
}
// Create token array [1, L]
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
// Forward through encoder
output := m.Forward(tokensArr)
mlx.Eval(output)
allEmbeddings = append(allEmbeddings, output)
}
// Concatenate all glyph embeddings along sequence dimension
var output *mlx.Array
if len(allEmbeddings) == 0 {
// Fallback: return single zero embedding
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
} else if len(allEmbeddings) == 1 {
output = allEmbeddings[0]
} else {
output = mlx.Concatenate(allEmbeddings, 1)
}
mlx.Eval(output)
return output
}
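// Illustrative usage (the tokenizer value and model path are assumptions):
//
// enc := &T5TextEncoder{}
// _ = enc.LoadFromPath("/path/to/text_encoder")
// embeds := enc.EncodePrompt(tok, `a sign reading 'OPEN'`)
// // embeds: [1, L, d_model]; only 'OPEN' (plus EOS) is encoded, not the
// // full prompt.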
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
cfg := m.Config
// Create relative position matrix
// For each (query_pos, key_pos) pair, compute bucketed relative position
numBuckets := cfg.RelativeAttentionNumBuckets
maxDistance := cfg.RelativeAttentionMaxDistance
// Create position indices
contextPos := make([]int32, seqLen*seqLen)
memoryPos := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen; i++ {
for j := int32(0); j < seqLen; j++ {
contextPos[i*seqLen+j] = i
memoryPos[i*seqLen+j] = j
}
}
// Compute relative positions and bucket them
buckets := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen*seqLen; i++ {
relPos := memoryPos[i] - contextPos[i]
buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
}
// Create bucket indices array
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
// Transpose to [numHeads, seqLen, seqLen]
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
return bias
}
// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
var bucket int32 = 0
var n int32 = -relativePosition
if bidirectional {
numBuckets /= 2
if n < 0 {
bucket += numBuckets
n = -n
}
} else {
if n < 0 {
n = 0
}
}
// Half buckets are for exact positions, half are for log-spaced
maxExact := numBuckets / 2
if n < maxExact {
bucket += n
} else {
// Log-spaced buckets
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
if bucket > numBuckets-1 {
bucket = numBuckets - 1
}
}
return bucket
}
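// Worked example (illustrative), with num_buckets=32, max_distance=128,
// bidirectional=false, so maxExact=16:
// - relPos = -5  -> n = 5, below maxExact, exact bucket 5
// - relPos = -40 -> n = 40, log bucket 16 + int32(16*ln(40/16)/ln(128/16)) = 23
// - relPos = +5  -> n = -5 is clamped to 0, bucket 0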
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Self attention with residual
h := b.Layer0.Forward(x, posBias, eps)
// FFN with residual
h = b.Layer1.Forward(h, eps)
return h
}
// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// Attention
attnOut := l.SelfAttention.Forward(normed, posBias)
// Residual
return mlx.Add(x, attnOut)
}
// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Q, K, V projections (no bias)
// Weights are [out_features, in_features], so we use matmul with transpose
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
// Reshape to [B, L, nheads, d_kv]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
// Transpose to [B, nheads, L, d_kv]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Attention scores with relative position bias
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
// (no 1/sqrt(d_k) scale factor like in standard transformers)
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
scores = mlx.Add(scores, posBias)
// Softmax
attnWeights := mlx.Softmax(scores, -1)
// Attend to values
out := mlx.Matmul(attnWeights, v)
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, inner_dim] where inner_dim = n_heads * d_kv
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
_ = D // Silence unused warning
return out
}
// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// FFN
ffOut := l.DenseReluDense.Forward(normed)
// Residual
return mlx.Add(x, ffOut)
}
// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
coeff := float32(0.044715)
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
gate = geluNew(gate)
// Up projection
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
// Gated output
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}
// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
variance := mlx.Mean(mlx.Square(x), -1, true)
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
return mlx.Mul(x, ln.Weight)
}

View File

File diff suppressed because it is too large

View File

@@ -1,477 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"` // 3
OutChannels int32 `json:"out_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 16
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
LayersPerBlock int32 `json:"layers_per_block"` // 3
NormNumGroups int32 `json:"norm_num_groups"` // 32
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
ShiftFactor *float32 `json:"shift_factor"` // null
LatentsMean []float32 `json:"latents_mean"` // [16 values]
LatentsStd []float32 `json:"latents_std"` // [16 values]
}
// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
Config *VAEConfig
// Decoder components
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
}
// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Stride int32
Padding int32
}
// GroupNorm is group normalization
type GroupNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
}
// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
}
// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
Norm1 *GroupNorm `weight:"norm1"`
Conv1 *VAEConv2d `weight:"conv1"`
Norm2 *GroupNorm `weight:"norm2"`
Conv2 *VAEConv2d `weight:"conv2"`
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}
// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
Conv *VAEConv2d `weight:"conv"`
}
// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
// And all but the last up_block has an upsampler
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
// All but the last block has upsamplers
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
func (m *VAEDecoder) initGroupNorms() {
cfg := m.Config
numGroups := cfg.NormNumGroups
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
if m.ConvNormOut != nil {
m.ConvNormOut.NumGroups = numGroups
m.ConvNormOut.Eps = eps
}
if m.MidBlock != nil {
for _, resnet := range m.MidBlock.Resnets {
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
for _, upBlock := range m.UpBlocks {
if upBlock == nil {
continue
}
for _, resnet := range upBlock.Resnets {
if resnet == nil {
continue
}
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Apply latent denormalization if mean/std are provided
// This matches diffusers GLM-Image: latents = latents * std + mean
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
latents = m.denormalizeLatents(latents)
}
// Convert from NCHW to NHWC for processing
// [B, C, H, W] -> [B, H, W, C]
x := mlx.Transpose(latents, 0, 2, 3, 1)
// Initial convolution
x = m.ConvIn.Forward(x)
// Mid block
x = m.MidBlock.Forward(x)
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
for i := 0; i < len(m.UpBlocks); i++ {
if m.UpBlocks[i] != nil {
x = m.UpBlocks[i].Forward(x)
}
}
// Final normalization and convolution
x = m.ConvNormOut.Forward(x)
x = mlx.SiLU(x)
x = m.ConvOut.Forward(x)
// Convert back to NCHW
// [B, H, W, C] -> [B, C, H, W]
x = mlx.Transpose(x, 0, 3, 1, 2)
// Clamp to valid range and convert to [0, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.AddScalar(x, 1.0)
x = mlx.DivScalar(x, 2.0)
return x
}
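// Illustrative shape walk-through (assuming the config comments above:
// 16 latent channels and block_out_channels [128, 512, 1024, 1024], i.e.
// three 2x upsamplers): a [1, 16, 128, 128] latent decodes to a
// [1, 3, 1024, 1024] image with values in [0, 1].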
// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Create mean and std arrays [1, C, 1, 1] for broadcasting
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
// Denormalize: latents * std + mean
latents = mlx.Mul(latents, std)
latents = mlx.Add(latents, mean)
return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C_in] (NHWC)
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
// So we need to transpose from OIHW to OHWI
stride := c.Stride
if stride == 0 {
stride = 1
}
padding := c.Padding
if padding == 0 {
// Default to same padding for 3x3 kernels
wShape := c.Weight.Shape()
if len(wShape) >= 3 && wShape[2] == 3 {
padding = 1
}
}
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
out := mlx.Conv2d(x, weight, stride, padding)
if c.Bias != nil {
// Bias: [C_out] -> [1, 1, 1, C_out]
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
out = mlx.Add(out, bias)
}
return out
}
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
numGroups := gn.NumGroups
if numGroups == 0 {
numGroups = 32
}
groupSize := C / numGroups
// Reshape to [B, H, W, groups, groupSize]
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range mb.Resnets {
x = resnet.Forward(x)
}
return x
}
// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
// Apply resnets
for _, resnet := range ub.Resnets {
if resnet != nil {
x = resnet.Forward(x)
}
}
// Apply upsamplers
for _, upsampler := range ub.Upsamplers {
if upsampler != nil {
x = upsampler.Forward(x)
}
}
return x
}
// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
// First norm + activation + conv
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
// Second norm + activation + conv
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
// Shortcut for channel mismatch
if rb.ConvShortcut != nil {
residual = rb.ConvShortcut.Forward(residual)
}
return mlx.Add(h, residual)
}
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C]
// 2x nearest neighbor upsample
x = upsample2x(x)
// Conv
if us.Conv != nil {
x = us.Conv.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
hIndices := make([]int32, H*2)
for i := int32(0); i < H; i++ {
hIndices[i*2] = i
hIndices[i*2+1] = i
}
wIndices := make([]int32, W*2)
for i := int32(0); i < W; i++ {
wIndices[i*2] = i
wIndices[i*2+1] = i
}
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
// Take along height axis
x = mlx.Reshape(x, B*H, W, C)
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
x = mlx.Reshape(x, B, H, W*2, C)
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
x = mlx.Reshape(x, B*(W*2), H, C)
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
x = mlx.Reshape(x, B, W*2, H*2, C)
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
return x
}
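// Illustrative example: a [1, 4, 4, 16] input becomes [1, 8, 8, 16]; each
// spatial position is duplicated into a 2x2 block, e.g. row indices
// [0, 1, 2, 3] expand to [0, 0, 1, 1, 2, 2, 3, 3].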

View File

@@ -1,982 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 4096
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
IntermediateSize int32 `json:"intermediate_size"` // 13696
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
VocabSize int32 `json:"vocab_size"` // 168064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
// RoPE config
RopeTheta float32 `json:"rope_theta"` // 10000
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
// Visual token config
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
ImageTokenID int32 `json:"image_token_id"` // 167855
// Computed
HeadDim int32
}
// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
Config *VisionLanguageConfig
// Embedding
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
// Transformer layers
Layers []*GLMBlock `weight:"model.language_model.layers"`
// Final norm
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
// LM Head
LMHead *mlx.Array `weight:"lm_head.weight"`
}
// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
// Pre-attention norm (GLM uses post-LN variant)
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
// Attention
SelfAttn *GLMAttention `weight:"self_attn"`
// MLP (fused gate_up)
MLP *GLMMLP `weight:"mlp"`
}
// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
QProj *mlx.Array `weight:"q_proj.weight"`
KProj *mlx.Array `weight:"k_proj.weight"`
VProj *mlx.Array `weight:"v_proj.weight"`
OProj *mlx.Array `weight:"o_proj.weight"`
// QKV have biases in GLM
QBias *mlx.Array `weight:"q_proj.bias"`
KBias *mlx.Array `weight:"k_proj.bias"`
VBias *mlx.Array `weight:"v_proj.bias"`
// Computed
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
PartialRotary float32 // Only rotate this fraction of head_dim
RopeTheta float32
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
Layers []cache.Cache
}
// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
layers := make([]cache.Cache, numLayers)
for i := range layers {
layers[i] = cache.NewKVCache()
}
return &ARCache{Layers: layers}
}
// Free releases all cached tensors
func (c *ARCache) Free() {
for _, layer := range c.Layers {
for _, arr := range layer.State() {
if arr != nil {
arr.Free()
}
}
}
}
// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
DownProj *mlx.Array `weight:"down_proj.weight"`
}
// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
return fmt.Errorf("config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &rawCfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
func (m *VisionLanguageEncoder) initComputedFields() {
cfg := m.Config
for _, block := range m.Layers {
block.SelfAttn.NHeads = cfg.NumAttentionHeads
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
block.SelfAttn.HeadDim = cfg.HeadDim
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
block.SelfAttn.RopeTheta = cfg.RopeTheta
block.SelfAttn.MRoPESection = cfg.MRoPESection
// Set norm eps
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
}
m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
prompt string,
tok *GLMTokenizer,
maxTokens int32,
temperature float32,
topP float32,
seed int64,
targetHeight, targetWidth int32,
progressFn func(int),
) *mlx.Array {
cfg := m.Config
// Encode prompt with grid tokens using GLM tokenizer
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
// Calculate grid dimensions for MRoPE position IDs
factor := int32(32)
tokenH := targetHeight / factor
tokenW := targetWidth / factor
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
prevGridSize := prevTokenH * prevTokenW
// Create KV cache for all layers
cache := NewARCache(cfg.NumHiddenLayers)
defer cache.Free()
// ===== PREFILL PHASE =====
// Process entire prompt at once, populate cache
promptLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
mlx.Eval(h)
// Compute position IDs for prefill (text tokens use same position for all dims)
prefillPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
prefillPositions[dim] = make([]int32, promptLen)
for i := int32(0); i < promptLen; i++ {
prefillPositions[dim][i] = i
}
}
// Forward through layers (prefill)
for i, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
if i > 0 {
oldH.Free()
}
}
// Eval h and cache arrays together so cache is materialized
evalArgs := []*mlx.Array{h}
for _, lc := range cache.Layers {
evalArgs = append(evalArgs, lc.State()...)
}
mlx.Eval(evalArgs...)
// Final norm and get logits for last position
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
h.Free()
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
lastH.Free()
// Sample first token
var sampleCounter int64 = 0
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
// AR generation loop with caching
// Visual tokens are stored as VQ codebook indices [0, 16383]
// The LM head outputs indices [0, 16511] where:
// - [0, 16383] are VQ codes
// - 16384 is BOS
// - 16385 is EOS
visualTokens := make([]int32, 0, maxTokens)
posOffset := promptLen
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
// Preallocate slice for old cache state to reuse
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for i := int32(0); i < maxTokens; i++ {
if progressFn != nil {
progressFn(int(i))
}
// Check for end token (EOS = 16385)
if nextToken == cfg.ImageEndTokenID {
break
}
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
if nextToken == cfg.ImageStartTokenID {
// BOS token - skip storing but continue generation
} else if nextToken < cfg.ImageStartTokenID {
// This is an actual VQ code [0, 16383] - store it
visualTokens = append(visualTokens, nextToken)
}
// Tokens >= 16386 are other special tokens, skip them
// ===== DECODE PHASE =====
// Save old cache state before forward (to free after eval)
oldCacheState = oldCacheState[:0]
for _, lc := range cache.Layers {
oldCacheState = append(oldCacheState, lc.State()...)
}
// Only process the new token, use cached K,V
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
// Compute MRoPE position IDs for this visual token
// Visual tokens are arranged in two grids: prev grid then target grid
// Position dimensions: [temporal, height, width]
decodePositions := computeVisualTokenPositions(
visualTokenIdx, posOffset, promptLen,
prevTokenH, prevTokenW, prevGridSize,
tokenH, tokenW,
)
// Forward through layers (decode with cache)
for j, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
if j > 0 { // Don't free the embedding on first layer
oldH.Free()
}
}
// Eval h and new cache state
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for _, lc := range cache.Layers {
newCacheState = append(newCacheState, lc.State()...)
}
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
// Free old cache state (now that new state is evaluated)
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
// Final norm
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
// Get logits (h is already [1, 1, hidden_size])
h = mlx.Reshape(h, 1, cfg.HiddenSize)
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
h.Free()
// Sample next token
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
posOffset++
visualTokenIdx++
// Periodically clear cache to release intermediate memory
if i%256 == 0 {
mlx.ClearCache()
}
}
if len(visualTokens) == 0 {
// Return at least one token to avoid empty tensor issues
visualTokens = append(visualTokens, 0)
}
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
visualIdx int32, absPos int32, promptLen int32,
prevH, prevW, prevSize int32,
targetH, targetW int32,
) [][]int32 {
positions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
positions[dim] = make([]int32, 1)
}
// First grid (prev grid) starts at decode_pos = promptLen
prevGridDecodePos := promptLen
// Second grid (target grid) starts after first grid
// next_pos = prev_decode_pos + max(prevH, prevW)
maxPrev := prevH
if prevW > maxPrev {
maxPrev = prevW
}
targetGridDecodePos := prevGridDecodePos + maxPrev
// Compute position IDs based on which grid the token is in
if visualIdx < prevSize {
// Token is in the prev grid (prev_token_h × prev_token_w)
row := visualIdx / prevW
col := visualIdx % prevW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = prevGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = prevGridDecodePos + row
positions[2][0] = prevGridDecodePos + col
} else {
// Token is in the target grid (token_h × token_w)
targetIdx := visualIdx - prevSize
row := targetIdx / targetW
col := targetIdx % targetW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = targetGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = targetGridDecodePos + row
positions[2][0] = targetGridDecodePos + col
}
_ = targetH // Used for documentation clarity
_ = absPos // No longer used - kept for API compatibility
return positions
}
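// Worked example (illustrative): promptLen=100 and a square 1024x1024
// target, so tokenH=tokenW=32, prevTokenH=prevTokenW=16, prevSize=256:
// - visualIdx=0 (prev grid, row 0, col 0): positions (100, 100, 100)
// - visualIdx=300 (target grid, targetIdx=44, row 1, col 12):
//   targetGridDecodePos = 100 + 16 = 116, positions (116, 117, 128)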
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
// Output index directly corresponds to vocab ID [0, 16511]
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
visualLogits := logits
// Apply temperature
if temperature != 1.0 && temperature > 0 {
visualLogits = mlx.DivScalar(visualLogits, temperature)
}
// Apply softmax to get probabilities
probs := mlx.Softmax(visualLogits, -1)
mlx.Eval(probs)
// Get the sampled index using top-p sampling
// This directly gives us the vocab ID in [0, 16511]
// Special tokens: 16384 = BOS, 16385 = EOS
// Use seed + counter for reproducible but different random values
effectiveSeed := seed + *sampleCounter
*sampleCounter++
return sampleTopP(probs, topP, effectiveSeed)
}
// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
// Negate probs for descending sort (Argsort only does ascending)
negProbs := mlx.MulScalar(probs, -1)
sortedIndices := mlx.Argsort(negProbs, -1)
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
cumProbs := mlx.Cumsum(sortedProbs, -1)
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
// Find cutoff index where cumulative probability exceeds topP
probsData := sortedProbs.Data()
cumProbsData := cumProbs.Data()
indicesData := sortedIndices.DataInt32()
// Calculate cutoff and renormalize
var cutoffIdx int
var totalProb float32
for i, cp := range cumProbsData {
totalProb += probsData[i]
if cp >= topP {
cutoffIdx = i + 1 // Include this token
break
}
}
if cutoffIdx == 0 {
cutoffIdx = len(probsData) // Use all tokens if topP is very high
}
// Sample from the truncated distribution
// Renormalize the truncated probabilities
truncatedProbs := make([]float32, cutoffIdx)
for i := 0; i < cutoffIdx; i++ {
truncatedProbs[i] = probsData[i] / totalProb
}
// Sample using random number with provided seed for reproducibility
r := mlx.RandomUniform([]int32{1}, uint64(seed))
mlx.Eval(r)
randVal := r.Data()[0]
// Find the sampled token
var cumulative float32
for i := 0; i < cutoffIdx; i++ {
cumulative += truncatedProbs[i]
if randVal < cumulative {
return indicesData[i]
}
}
// Fallback to the last token in truncated set
return indicesData[cutoffIdx-1]
}
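// Worked example (illustrative): sorted probs [0.5, 0.3, 0.15, 0.05] with
// topP=0.75. Cumulative sums are [0.5, 0.8, ...]; 0.8 >= 0.75, so
// cutoffIdx=2 and the kept mass is 0.8. Renormalized probs are
// [0.625, 0.375]: a uniform draw r < 0.625 returns the top token, otherwise
// the second.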
// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}
// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
// Pre-attention norm
normed := b.InputLayerNorm.Forward(x, eps)
// Self-attention with RoPE/MRoPE and cache
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
// Post-attention norm (GLM-4 style)
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
// Residual connection
x = mlx.Add(x, attnOut)
// Post-attention layer norm
normed = b.PostAttnLayerNorm.Forward(x, eps)
// MLP
mlpOut := b.MLP.Forward(normed)
// Post-MLP norm
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
// Residual connection
x = mlx.Add(x, mlpOut)
return x
}
// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}
// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
// Q, K, V projections
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
// Add biases
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// Apply partial RoPE or MRoPE
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
if len(attn.MRoPESection) == 3 && positionIDs != nil {
// Use MRoPE with explicit position IDs
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else if len(attn.MRoPESection) == 3 {
// Use MRoPE with sequential positions (same for all dims - text mode)
seqPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
seqPositions[dim] = make([]int32, L)
for i := int32(0); i < L; i++ {
seqPositions[dim][i] = i + posOffset
}
}
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else {
// Fallback to standard RoPE
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Update cache and get full K, V for attention
if kvcache != nil {
k, v = kvcache.Update(k, v, int(L))
}
// Repeat KV for GQA
kExpanded := k
vExpanded := v
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
kExpanded = repeatKV(k, repeats)
vExpanded = repeatKV(v, repeats)
}
// Scaled dot-product attention with causal mask
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, hidden_size]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
return out
}
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}
// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply RoPE to rotary part with position offset
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply MRoPE to rotary part
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Validate mrope_section sums to half (number of frequency pairs)
var totalPairs int32
for _, s := range mropeSection {
totalPairs += s
}
if totalPairs != half {
// Fallback to standard RoPE if section doesn't match
return applyRoPEWithOffset(x, L, 0, theta)
}
// Build angles for each position dimension (matching Python's MRoPE approach)
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
angleVals := make([]*mlx.Array, 3)
freqOffset := int32(0)
for dim := 0; dim < 3; dim++ {
numPairs := mropeSection[dim]
if numPairs == 0 {
continue
}
// Compute inverse frequencies for this section
// Each dimension uses DIFFERENT frequency ranges:
// - Temporal: frequencies 0 to section[0]-1
// - Height: frequencies section[0] to section[0]+section[1]-1
// - Width: frequencies section[0]+section[1] to sum(section)-1
freqsArr := make([]float32, numPairs)
for i := int32(0); i < numPairs; i++ {
globalIdx := freqOffset + i
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
// Position indices for this dimension
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(positionIDs[dim][i])
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, numPairs] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
freqOffset += numPairs
}
// Concatenate all sections: [L, half] = [L, 32]
allAngles := mlx.Concatenate(angleVals, 1)
// Duplicate AFTER concatenation: [L, D] = [L, 64]
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
// Compute cos/sin
allCos := mlx.Cos(allAngles)
allSin := mlx.Sin(allAngles)
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
allCos = mlx.Reshape(allCos, 1, L, 1, D)
allSin = mlx.Reshape(allSin, 1, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
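// Worked example (illustrative): with head_dim=128 and
// partial_rotary_factor=0.5, rotaryDim=64, so half=32 frequency pairs, which
// matches mrope_section [8, 12, 12]. Temporal positions rotate global
// frequency indices 0-7, height 8-19, width 20-31; after duplication the
// per-position angle layout is [t8, h12, w12, t8, h12, w12].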
// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
return applyRoPEWithOffset(x, seqLen, 0, theta)
}
// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Compute inverse frequencies: 1 / (theta^(2i/d))
freqsArr := make([]float32, half)
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
// Position indices with offset
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(i + posOffset)
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, half] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
angles := mlx.Mul(posExpanded, freqsExpanded)
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
cosVals := mlx.Cos(anglesDup)
sinVals := mlx.Sin(anglesDup)
cosVals = mlx.Reshape(cosVals, L, 1, D)
sinVals = mlx.Reshape(sinVals, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
// x: [B, nkvheads, L, head_dim]
x = mlx.ExpandDims(x, 2)
// x: [B, nkvheads, 1, L, head_dim]
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
// x: [B, nkvheads, repeats, L, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
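
Shape bookkeeping for the GQA repeat, traced with hypothetical sizes (2 KV heads repeated 4x); a standalone sketch, not part of this diff:

package main

import "fmt"

func main() {
	B, nkv, repeats, L, hd := 1, 2, 4, 3, 64
	fmt.Println([]int{B, nkv, L, hd})           // input:              [1 2 3 64]
	fmt.Println([]int{B, nkv, 1, L, hd})        // after ExpandDims(2): [1 2 1 3 64]
	fmt.Println([]int{B, nkv, repeats, L, hd})  // after Tile:          [1 2 4 3 64]
	fmt.Println([]int{B, nkv * repeats, L, hd}) // after Reshape:       [1 8 3 64]
}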
// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
// gate_up_proj outputs [gate, up] concatenated
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
shape := gateUp.Shape()
halfDim := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
gate = mlx.SiLU(gate)
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}
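
Per element, the SwiGLU activation reduces to silu(gate) * up; a scalar sketch of that step (illustration only):

package main

import (
	"fmt"
	"math"
)

// silu(x) = x * sigmoid(x)
func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

func main() {
	// One (gate, up) pair, mirroring silu(gate) * up above.
	gate, up := 1.5, 0.8
	fmt.Printf("%.4f\n", silu(gate)*up) // ≈ 0.9811
}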

View File

@@ -3,12 +3,33 @@
package qwen_image
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if model weights are not found. Requires ~50GB VRAM.
func TestPipelineOutput(t *testing.T) {

View File

@@ -172,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
cfg.Steps = 50
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

View File

@@ -3,13 +3,35 @@
package qwen_image_edit
import (
"fmt"
"math"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestComputeAxisFreqs verifies frequency computation matches Python reference
func TestComputeAxisFreqs(t *testing.T) {
theta := float64(10000)

View File

@@ -194,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Turbo default
cfg.Steps = 9 // Z-Image turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0

View File

@@ -3,12 +3,34 @@
package nn
import (
"fmt"
"math"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestMain initializes MLX before running tests.
// If MLX libraries are not available, tests are skipped.
func TestMain(m *testing.M) {
// Change to repo root so ./build/lib/ollama/ path works
_, thisFile, _, _ := runtime.Caller(0)
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
if err := os.Chdir(repoRoot); err != nil {
fmt.Printf("Failed to change to repo root: %v\n", err)
os.Exit(1)
}
if err := mlx.InitMLX(); err != nil {
fmt.Printf("Skipping nn tests: %v\n", err)
os.Exit(0)
}
os.Exit(m.Run())
}
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
func TestLinearNoBias(t *testing.T) {
// Weight: [out=2, in=3] -> transposed at forward time

View File

@@ -1,22 +0,0 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}
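
Spot-checks of the removed predicate's behavior, using a local copy for illustration (the tensor names below are hypothetical):

package main

import (
	"fmt"
	"strings"
)

// shouldQuantize is a local copy of the removed helper, for illustration only.
func shouldQuantize(name, component string) bool {
	if component == "vae" {
		return false
	}
	if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
		return false
	}
	return strings.HasSuffix(name, ".weight")
}

func main() {
	fmt.Println(shouldQuantize("transformer.layers.0.attn.q.weight", "transformer")) // true: linear weight
	fmt.Println(shouldQuantize("vae.decoder.conv_in.weight", "vae"))                 // false: VAE skipped
	fmt.Println(shouldQuantize("model.embed_tokens.weight", "text_encoder"))         // false: embedding
	fmt.Println(shouldQuantize("transformer.norm_out.weight", "transformer"))        // false: norm
	fmt.Println(shouldQuantize("transformer.layers.0.attn.q.bias", "transformer"))   // false: not .weight
}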

View File

@@ -19,15 +19,9 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
@@ -42,14 +36,15 @@ type Response struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"` // Base64-encoded PNG
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model ImageModel
model *zimage.Model
modelName string
modelType string // "zimage" or "glm_image"
}
// Execute is the entry point for the image runner subprocess
@@ -69,6 +64,12 @@ func Execute(args []string) error {
return fmt.Errorf("--port is required")
}
err := mlx.InitMLX()
if err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
slog.Info("starting image runner", "model", *modelName, "port", *port)
// Check memory requirements before loading
@@ -79,35 +80,15 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Detect model type and load appropriate model
modelType, err := detectModelType(*modelName)
if err != nil {
return fmt.Errorf("failed to detect model type: %w", err)
}
var model ImageModel
switch modelType {
case "GlmImagePipeline":
slog.Info("loading GLM-Image model")
m := &glm_image.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load GLM-Image model: %w", err)
}
model = m
default:
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
slog.Info("loading Z-Image model")
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load Z-Image model: %w", err)
}
model = m
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
server := &Server{
model: model,
modelName: *modelName,
modelType: modelType,
}
// Set up HTTP handlers
@@ -163,22 +144,8 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()
// Apply defaults
if req.Width <= 0 {
req.Width = 1024
}
if req.Height <= 0 {
req.Height = 1024
}
if req.Steps <= 0 {
// Default steps depend on model type
switch s.modelType {
case "GlmImagePipeline":
req.Steps = 50 // GLM-Image default
default:
req.Steps = 9 // Z-Image turbo default
}
}
// Model applies its own defaults for width/height/steps
// Only seed needs to be set here if not provided
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
}
@@ -192,9 +159,26 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image using interface method
// Generate image
ctx := r.Context()
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Step: step,
Total: total,
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
if err != nil {
// Don't send error for cancellation
@@ -233,35 +217,3 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("\n"))
flusher.Flush()
}
// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", err
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return "ZImagePipeline", nil // Default to Z-Image
}
// Try both _class_name (diffusers format) and architecture (ollama format)
var index struct {
ClassName string `json:"_class_name"`
Architecture string `json:"architecture"`
}
if err := json.Unmarshal(data, &index); err != nil {
return "ZImagePipeline", nil
}
// Prefer _class_name, fall back to architecture
className := index.ClassName
if className == "" {
className = index.Architecture
}
if className == "" {
return "ZImagePipeline", nil
}
return className, nil
}
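
The handler above streams newline-delimited JSON: progress objects carrying step/total, then a final object with done set and the base64 image. A minimal client-side sketch, assuming a local runner port and request shape (illustration only):

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// The port and request fields here are assumptions for illustration.
	body := strings.NewReader(`{"prompt":"a red bicycle","seed":42}`)
	resp, err := http.Post("http://127.0.0.1:12345/completion", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	sc := bufio.NewScanner(resp.Body)
	sc.Buffer(make([]byte, 1024*1024), 16*1024*1024) // final image line can be large
	for sc.Scan() {
		var msg struct {
			Image string `json:"image,omitempty"`
			Done  bool   `json:"done"`
			Step  int    `json:"step,omitempty"`
			Total int    `json:"total,omitempty"`
		}
		if err := json.Unmarshal(sc.Bytes(), &msg); err != nil {
			continue
		}
		if msg.Done {
			fmt.Printf("done: %d base64 bytes\n", len(msg.Image))
			return
		}
		fmt.Printf("step %d/%d\n", msg.Step, msg.Total)
	}
}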

View File

@@ -25,6 +25,11 @@ import (
)
// Server wraps an image generation subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
@@ -37,22 +42,6 @@ type Server struct {
lastErrLock sync.Mutex
}
// completionRequest is sent to the subprocess
type completionRequest struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// completionResponse is received from the subprocess
type completionResponse struct {
Content string `json:"content,omitempty"`
Image string `json:"image,omitempty"`
Done bool `json:"done"`
}
// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// Validate platform support before attempting to start
@@ -72,7 +61,7 @@ func NewServer(modelName string) (*Server, error) {
port = rand.Intn(65535-49152) + 49152
}
// Get the ollama-mlx executable path (in same directory as current executable)
// Get the current executable path (we use the same binary with runner subcommand)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
@@ -80,10 +69,9 @@ func NewServer(modelName string) (*Server, error) {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx")
// Spawn subprocess: ollama-mlx runner --image-engine --model <path> --port <port>
cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -139,14 +127,13 @@ func NewServer(modelName string) (*Server, error) {
for scanner.Scan() {
line := scanner.Text()
slog.Warn("image-runner", "msg", line)
// Capture last error line for better error reporting
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port)
slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start image runner: %w", err)
}
@@ -171,7 +158,6 @@ func (s *Server) ModelPath() string {
return s.modelName
}
// Load is called by the scheduler after the server is created.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -204,20 +190,16 @@ func (s *Server) waitUntilRunning() error {
for {
select {
case err := <-s.done:
// Include last stderr line for better error context
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", lastErr, err)
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("image runner exited unexpectedly: %w", err)
case <-timeout:
s.lastErrLock.Lock()
lastErr := s.lastErr
s.lastErrLock.Unlock()
if lastErr != "" {
return fmt.Errorf("timeout waiting for image runner: %s", lastErr)
errMsg := s.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
}
return errors.New("timeout waiting for image runner to start")
case <-ticker.C:
@@ -229,44 +211,41 @@ func (s *Server) waitUntilRunning() error {
}
}
// WaitUntilRunning implements the LlamaServer interface (no-op since NewServer waits).
func (s *Server) WaitUntilRunning(ctx context.Context) error {
return nil
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()
defer s.lastErrLock.Unlock()
return s.lastErr
}
// Completion generates an image from the prompt via the subprocess.
func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Build request
creq := completionRequest{
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: 1024,
Height: 1024,
Steps: 9,
Seed: time.Now().UnixNano(),
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
}
if req.Options != nil {
if req.Options.NumCtx > 0 && req.Options.NumCtx <= 4096 {
creq.Width = int32(req.Options.NumCtx)
}
if req.Options.NumGPU > 0 && req.Options.NumGPU <= 4096 {
creq.Height = int32(req.Options.NumGPU)
}
if req.Options.NumPredict > 0 && req.Options.NumPredict <= 100 {
creq.Steps = req.Options.NumPredict
}
if req.Options.Seed > 0 {
creq.Seed = int64(req.Options.Seed)
}
}
// Encode request body
body, err := json.Marshal(creq)
if err != nil {
return err
}
// Send request to subprocess
url := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
@@ -281,30 +260,36 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("completion request failed: %d", resp.StatusCode)
return fmt.Errorf("request failed: %d", resp.StatusCode)
}
// Stream responses - use large buffer for base64 image data
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
var cresp completionResponse
if err := json.Unmarshal(scanner.Bytes(), &cresp); err != nil {
// Parse subprocess response (has singular "image" field)
var raw struct {
Image string `json:"image,omitempty"`
Content string `json:"content,omitempty"`
Done bool `json:"done"`
Step int `json:"step,omitempty"`
Total int `json:"total,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
continue
}
content := cresp.Content
// If this is the final response with an image, encode it in the content
if cresp.Done && cresp.Image != "" {
content = "IMAGE_BASE64:" + cresp.Image
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
Step: raw.Step,
TotalSteps: raw.Total,
Image: raw.Image,
}
fn(llm.CompletionResponse{
Content: content,
Done: cresp.Done,
})
fn(cresp)
if cresp.Done {
break
return nil
}
}
@@ -346,22 +331,18 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
// Embedding is not supported for image generation models.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("embedding not supported for image generation models")
return nil, 0, errors.New("not supported")
}
// Tokenize is not supported for image generation models.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
return nil, errors.New("tokenize not supported for image generation models")
return nil, errors.New("not supported")
}
// Detokenize is not supported for image generation models.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("detokenize not supported for image generation models")
return "", errors.New("not supported")
}
// Pid returns the subprocess PID.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -371,17 +352,9 @@ func (s *Server) Pid() int {
return -1
}
// GetPort returns the subprocess port.
func (s *Server) GetPort() int {
return s.port
}
func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
// GetDeviceInfos returns nil since we don't track GPU info.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// HasExited returns true if the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:

View File

@@ -1,5 +1,9 @@
include(FetchContent)
# Read MLX version from top-level file (shared with Dockerfile)
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
set(MLX_C_BUILD_EXAMPLES OFF)
set(MLX_BUILD_GGUF OFF)
@@ -50,7 +54,7 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG v0.4.1)
GIT_TAG ${MLX_C_GIT_TAG})
FetchContent_MakeAvailable(mlx-c)
set_target_output_directory(mlx)

View File

@@ -0,0 +1,92 @@
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
#include "mlx_dynamic.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
typedef HMODULE lib_handle_t;
#define LOAD_LIB(path) LoadLibraryA(path)
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
#define CLOSE_LIB(handle) FreeLibrary(handle)
#define LIB_ERROR() "LoadLibrary failed"
static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
#else
#include <dlfcn.h>
typedef void* lib_handle_t;
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define GET_SYMBOL(handle, name) dlsym(handle, name)
#define CLOSE_LIB(handle) dlclose(handle)
#define LIB_ERROR() dlerror()
#ifdef __APPLE__
static const char* LIB_NAMES[] = {
"libmlxc.dylib",
"@loader_path/../build/lib/ollama/libmlxc.dylib",
"@executable_path/../build/lib/ollama/libmlxc.dylib",
"build/lib/ollama/libmlxc.dylib",
"../build/lib/ollama/libmlxc.dylib",
NULL
};
#else
static const char* LIB_NAMES[] = {
"libmlxc.so",
"$ORIGIN/../build/lib/ollama/libmlxc.so",
"build/lib/ollama/libmlxc.so",
"../build/lib/ollama/libmlxc.so",
NULL
};
#endif
#endif
static lib_handle_t mlx_handle = NULL;
static int mlx_initialized = 0;
static char mlx_error_buffer[512] = {0};
// Initialize MLX dynamic library
// Returns 0 on success, -1 on failure
// On failure, call mlx_dynamic_error() to get error message
int mlx_dynamic_init(void) {
if (mlx_initialized) {
return 0; // Already initialized
}
// Try each possible library path
for (int i = 0; LIB_NAMES[i] != NULL; i++) {
mlx_handle = LOAD_LIB(LIB_NAMES[i]);
if (mlx_handle != NULL) {
mlx_initialized = 1;
// Reuse the message buffer to record which path loaded (informational, not an error).
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Successfully loaded %s", LIB_NAMES[i]);
return 0;
}
}
// Failed to load library
const char* err = LIB_ERROR();
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
"MLX: Failed to load libmlxc library. %s",
err ? err : "Unknown error");
return -1;
}
// Get the last error message
const char* mlx_dynamic_error(void) {
return mlx_error_buffer;
}
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void) {
return mlx_initialized;
}
// Cleanup (optional, called at program exit)
void mlx_dynamic_cleanup(void) {
if (mlx_handle != NULL) {
CLOSE_LIB(mlx_handle);
mlx_handle = NULL;
mlx_initialized = 0;
}
}

View File

@@ -0,0 +1,26 @@
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef __cplusplus
extern "C" {
#endif
// Initialize the MLX dynamic library
// Returns 0 on success, -1 on failure
int mlx_dynamic_init(void);
// Get the last error message from dynamic loading
const char* mlx_dynamic_error(void);
// Check if MLX is initialized
int mlx_dynamic_is_initialized(void);
// Cleanup resources (optional, for clean shutdown)
void mlx_dynamic_cleanup(void);
#ifdef __cplusplus
}
#endif
#endif // MLX_DYNAMIC_H
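
A hypothetical cgo bridge for this interface; the actual mlx.InitMLX wiring is not shown in this diff, so the package and function names below are assumptions:

// Hypothetical sketch: Go wrapper around the dynamic loader above.
package mlx

/*
#include "mlx_dynamic.h"
*/
import "C"
import "errors"

// InitMLX loads libmlxc at runtime, returning a descriptive error on failure.
func InitMLX() error {
	if C.mlx_dynamic_init() != 0 {
		return errors.New(C.GoString(C.mlx_dynamic_error()))
	}
	return nil
}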

x/server/show.go (new file, 284 lines)
View File

@@ -0,0 +1,284 @@
package server
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/imagegen"
)
// modelConfig represents the HuggingFace config.json structure
type modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
VocabSize int `json:"vocab_size"`
RMSNormEps float64 `json:"rms_norm_eps"`
RopeTheta float64 `json:"rope_theta"`
TorchDtype string `json:"torch_dtype"`
TextConfig *struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
} `json:"text_config"`
}
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalBytes += layer.Size
tensorCount++
}
}
return buildModelInfo(config, totalBytes, tensorCount), nil
}
// buildModelInfo constructs the model info map from config and tensor stats.
// This is separated for testability.
func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any {
// Determine architecture
arch := config.ModelType
if arch == "" && len(config.Architectures) > 0 {
// Convert HuggingFace architecture name to Ollama format
// e.g., "Gemma3ForCausalLM" -> "gemma3"
hfArch := config.Architectures[0]
arch = strings.ToLower(hfArch)
arch = strings.TrimSuffix(arch, "forcausallm")
arch = strings.TrimSuffix(arch, "forconditionalgeneration")
}
// Use text_config values if they exist (for multimodal models)
hiddenSize := config.HiddenSize
maxPosEmbed := config.MaxPositionEmbeddings
numLayers := config.NumHiddenLayers
if config.TextConfig != nil {
if config.TextConfig.HiddenSize > 0 {
hiddenSize = config.TextConfig.HiddenSize
}
if config.TextConfig.MaxPositionEmbeddings > 0 {
maxPosEmbed = config.TextConfig.MaxPositionEmbeddings
}
if config.TextConfig.NumHiddenLayers > 0 {
numLayers = config.TextConfig.NumHiddenLayers
}
}
// Get dtype to determine bytes per parameter for count calculation
dtype := config.TorchDtype
// Determine bytes per parameter based on dtype
var bytesPerParam int64 = 2 // default to float16/bfloat16
switch strings.ToLower(dtype) {
case "float32":
bytesPerParam = 4
case "float16", "bfloat16":
bytesPerParam = 2
case "int8", "uint8":
bytesPerParam = 1
}
// Subtract safetensors header overhead (88 bytes per tensor file)
// Each tensor is stored as a minimal safetensors file
totalBytes := totalTensorBytes - tensorCount*88
paramCount := totalBytes / bytesPerParam
info := map[string]any{
"general.architecture": arch,
}
if maxPosEmbed > 0 {
info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed
}
if hiddenSize > 0 {
info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize
}
if numLayers > 0 {
info[fmt.Sprintf("%s.block_count", arch)] = numLayers
}
if config.NumAttentionHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads
}
if config.NumKeyValueHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads
}
if config.IntermediateSize > 0 {
info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize
}
if config.VocabSize > 0 {
info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize
}
if paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info
}
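
A worked example of the parameter estimate, using the gemma3 numbers from the tests below (bfloat16, so 2 bytes per parameter):

package main

import "fmt"

func main() {
	// Each tensor blob is a minimal safetensors file, so 88 bytes of
	// header overhead are subtracted per tensor before dividing by dtype width.
	totalTensorBytes := int64(8_600_000_088)
	tensorCount := int64(1)
	bytesPerParam := int64(2) // bfloat16
	params := (totalTensorBytes - tensorCount*88) / bytesPerParam
	fmt.Println(params) // 4300000000
}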
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
return getTensorInfoFromManifest(manifest)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
var tensors []api.Tensor
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
// Read the safetensors header from the blob
blobPath := manifest.BlobPath(layer.Digest)
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
continue
}
// Convert shape from int to uint64
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: info.Dtype,
Shape: shape,
})
}
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
// Check if model is quantized by looking for _scale tensors
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
if strings.HasSuffix(layer.Name, "_scale") {
// Model is quantized - return FP8 (affine quantization)
return "FP8", nil
}
}
}
// Not quantized - return torch_dtype from config.json
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}
return cfg.TorchDtype, nil
}
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
}
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return parseSafetensorsHeader(f)
}
// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
// Sanity check - header shouldn't be too large
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header size too large: %d", headerSize)
}
// Read header JSON
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
// Parse as map of tensor name -> info
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
// Find the first (and should be only) tensor entry
for name, raw := range header {
if name == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
return &info, nil
}
return nil, fmt.Errorf("no tensor found in header")
}

x/server/show_test.go (new file, 597 lines)
View File

@@ -0,0 +1,597 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen"
)
func TestBuildModelInfo(t *testing.T) {
tests := []struct {
name string
config modelConfig
totalTensorBytes int64
tensorCount int64
wantArch string
wantContextLen int
wantEmbedLen int
wantBlockCount int
wantParamCount int64
}{
{
name: "gemma3 model with model_type",
config: modelConfig{
ModelType: "gemma3",
HiddenSize: 2560,
NumHiddenLayers: 34,
MaxPositionEmbeddings: 131072,
IntermediateSize: 10240,
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "llama model with architectures array",
config: modelConfig{
Architectures: []string{"LlamaForCausalLM"},
HiddenSize: 4096,
NumHiddenLayers: 32,
MaxPositionEmbeddings: 4096,
IntermediateSize: 11008,
NumAttentionHeads: 32,
NumKeyValueHeads: 32,
VocabSize: 32000,
TorchDtype: "float16",
},
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
wantEmbedLen: 4096,
wantBlockCount: 32,
wantParamCount: 7_000_000_000,
},
{
name: "multimodal model with text_config",
config: modelConfig{
Architectures: []string{"Gemma3ForConditionalGeneration"},
HiddenSize: 1152, // vision hidden size
TextConfig: &struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
}{
HiddenSize: 2560,
MaxPositionEmbeddings: 131072,
NumHiddenLayers: 34,
},
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "float32 model",
config: modelConfig{
ModelType: "test",
HiddenSize: 512,
NumHiddenLayers: 6,
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
wantEmbedLen: 512,
wantBlockCount: 6,
wantParamCount: 100_000_000,
},
{
name: "multiple tensors with header overhead",
config: modelConfig{
ModelType: "test",
HiddenSize: 256,
NumHiddenLayers: 4,
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
wantEmbedLen: 256,
wantBlockCount: 4,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount)
// Check architecture
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
// Check context length
contextKey := tt.wantArch + ".context_length"
if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen {
t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen)
}
// Check embedding length
embedKey := tt.wantArch + ".embedding_length"
if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen {
t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen)
}
// Check block count
blockKey := tt.wantArch + ".block_count"
if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount {
t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount)
}
// Check parameter count
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestBuildModelInfo_ArchitectureConversion(t *testing.T) {
tests := []struct {
name string
architectures []string
modelType string
wantArch string
}{
{
name: "LlamaForCausalLM",
architectures: []string{"LlamaForCausalLM"},
wantArch: "llama",
},
{
name: "Gemma3ForCausalLM",
architectures: []string{"Gemma3ForCausalLM"},
wantArch: "gemma3",
},
{
name: "Gemma3ForConditionalGeneration",
architectures: []string{"Gemma3ForConditionalGeneration"},
wantArch: "gemma3",
},
{
name: "Qwen2ForCausalLM",
architectures: []string{"Qwen2ForCausalLM"},
wantArch: "qwen2",
},
{
name: "model_type takes precedence",
architectures: []string{"LlamaForCausalLM"},
modelType: "custom",
wantArch: "custom",
},
{
name: "empty architectures with model_type",
architectures: nil,
modelType: "mymodel",
wantArch: "mymodel",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
Architectures: tt.architectures,
ModelType: tt.modelType,
}
info := buildModelInfo(config, 0, 0)
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
})
}
}
func TestBuildModelInfo_BytesPerParam(t *testing.T) {
tests := []struct {
name string
dtype string
totalBytes int64
tensorCount int64
wantParamCount int64
}{
{
name: "bfloat16",
dtype: "bfloat16",
totalBytes: 2_000_088, // 1M * 2 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
totalBytes: 4_000_088, // 1M * 4 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
totalBytes: 1_000_088, // 1M * 1 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
ModelType: "test",
TorchDtype: tt.dtype,
}
info := buildModelInfo(config, tt.totalBytes, tt.tensorCount)
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantDtype string
wantShape []int64
wantErr bool
}{
{
name: "simple tensor",
header: map[string]any{
"weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 262144},
"data_offsets": []int64{0, 1342177280},
},
},
wantDtype: "BF16",
wantShape: []int64{2560, 262144},
},
{
name: "with metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
},
"bias": map[string]any{
"dtype": "F32",
"shape": []int64{1024},
"data_offsets": []int64{0, 4096},
},
},
wantDtype: "F32",
wantShape: []int64{1024},
},
{
name: "float16 tensor",
header: map[string]any{
"layer.weight": map[string]any{
"dtype": "F16",
"shape": []int64{512, 512, 3, 3},
"data_offsets": []int64{0, 4718592},
},
},
wantDtype: "F16",
wantShape: []int64{512, 512, 3, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create safetensors format: 8-byte size + JSON header
headerJSON, err := json.Marshal(tt.header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
buf.Write(headerJSON)
info, err := parseSafetensorsHeader(&buf)
if (err != nil) != tt.wantErr {
t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if info.Dtype != tt.wantDtype {
t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
}
if len(info.Shape) != len(tt.wantShape) {
t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
} else {
for i, s := range info.Shape {
if s != tt.wantShape[i] {
t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
}
}
}
})
}
}
func TestParseSafetensorsHeader_Errors(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr string
}{
{
name: "empty data",
data: []byte{},
wantErr: "failed to read header size",
},
{
name: "truncated header size",
data: []byte{0x01, 0x02, 0x03},
wantErr: "failed to read header size",
},
{
name: "header size too large",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
return buf.Bytes()
}(),
wantErr: "header size too large",
},
{
name: "truncated header",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(100))
buf.Write([]byte("short"))
return buf.Bytes()
}(),
wantErr: "failed to read header",
},
{
name: "invalid JSON",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(10))
buf.Write([]byte("not json!!"))
return buf.Bytes()
}(),
wantErr: "failed to parse header",
},
{
name: "no tensors in header",
data: func() []byte {
header := map[string]any{
"__metadata__": map[string]any{"format": "pt"},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
return buf.Bytes()
}(),
wantErr: "no tensor found in header",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
if err == nil {
t.Error("expected error, got nil")
return
}
if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) {
t.Errorf("error = %v, want error containing %v", err, tt.wantErr)
}
})
}
}
func TestGetTensorInfoFromManifest(t *testing.T) {
// Create a temp directory for blobs
tempDir := t.TempDir()
// Create test tensor blobs
tensors := []struct {
name string
digest string
dtype string
shape []int64
}{
{
name: "model.embed_tokens.weight",
digest: "sha256:abc123",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
digest: "sha256:def456",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
digest: "sha256:ghi789",
dtype: "F32",
shape: []int64{2560},
},
}
// Create blob files
var layers []imagegen.ManifestLayer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
"shape": tensor.shape,
"data_offsets": []int64{0, 1000},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
// Write blob file
blobName := "sha256-" + tensor.digest[7:]
blobPath := filepath.Join(tempDir, blobName)
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.tensor",
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
})
}
// Add a non-tensor layer (should be skipped)
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.json",
Digest: "sha256:config",
Size: 100,
Name: "config.json",
})
manifest := &imagegen.ModelManifest{
Manifest: &imagegen.Manifest{
Layers: layers,
},
BlobDir: tempDir,
}
result, err := getTensorInfoFromManifest(manifest)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
if len(result) != 3 {
t.Errorf("got %d tensors, want 3", len(result))
}
// Verify each tensor
for i, tensor := range tensors {
if i >= len(result) {
break
}
if result[i].Name != tensor.name {
t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name)
}
if result[i].Type != tensor.dtype {
t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype)
}
if len(result[i].Shape) != len(tensor.shape) {
t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape))
}
}
}
func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir := t.TempDir()
header := map[string]any{
"test_tensor": map[string]any{
"dtype": "BF16",
"shape": []int64{1024, 768},
"data_offsets": []int64{0, 1572864},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
filePath := filepath.Join(tempDir, "test.safetensors")
if err := os.WriteFile(filePath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
info, err := readSafetensorsHeader(filePath)
if err != nil {
t.Fatalf("readSafetensorsHeader() error = %v", err)
}
if info.Dtype != "BF16" {
t.Errorf("Dtype = %v, want BF16", info.Dtype)
}
if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
}
}
func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
_, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
if err == nil {
t.Error("expected error for nonexistent file")
}
}