Compare commits

2 Commits

Author     SHA1        Message                                      Date
jmorganca  bb1a5617b6  readme: add instructions to build with MLX   2026-01-15 09:52:56 -08:00
jmorganca  0d3648c1be  glm-image wip                                 2026-01-14 16:46:50 -08:00
39 changed files with 4716 additions and 575 deletions

View File

@@ -322,7 +322,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)

View File

@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send",
AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err

View File

@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

View File

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch("what is ollama?");
const results = await client.webSearch({ query: "what is ollama?" });
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch("https://ollama.com");
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
console.log(JSON.stringify(fetchResult, null, 2));
```

View File

@@ -111,9 +111,7 @@
"/integrations/zed",
"/integrations/roo-code",
"/integrations/n8n",
"/integrations/xcode",
"/integrations/onyx",
"/integrations/marimo"
"/integrations/xcode"
]
},
{

View File

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens.
By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
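Editorial aside on this FAQ entry: the context window can also be set per request from client code. Below is a minimal Go sketch assuming the standard `num_ctx` option and a placeholder model name (`llama3.2`); `OLLAMA_CONTEXT_LENGTH` remains the way to change the server-wide default.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.ChatRequest{
		Model:    "llama3.2", // placeholder model name
		Messages: []api.Message{{Role: "user", Content: "hello"}},
		// Request an 8K context window for this call only (assumed "num_ctx" option).
		Options: map[string]any{"num_ctx": 8192},
	}
	if err := client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		fmt.Print(r.Message.Content)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```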

View File

Nine binary image files removed (screenshots used by the marimo and Onyx integration docs below; 80–306 KiB each). Binary contents not shown.
View File

@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
1. Set the environment variables:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
@@ -39,7 +38,7 @@ claude --model qwen3-coder
Or run with environment variables inline:
```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```
## Connecting to ollama.com

View File

@@ -1,73 +0,0 @@
---
title: marimo
---
## Install
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:
```
uvx marimo edit --sandbox notebook.py
```
## Usage with Ollama
1. In marimo, open the user settings and go to the AI tab. From here
you can find and configure Ollama as an AI provider. For local use you
would typically point the base URL to `http://localhost:11434/v1`.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-settings.png"
alt="Ollama settings in marimo"
width="50%"
/>
</div>
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-models.png"
alt="Selecting an Ollama model"
width="50%"
/>
</div>
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-add-model.png"
alt="Adding a new Ollama model"
width="50%"
/>
</div>
4. Once configured, you can now use Ollama for AI chats in marimo.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-chat.png"
alt="Configure code completion"
width="50%"
/>
</div>
5. Alternatively, you can use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/marimo-code-completion.png"
alt="Configure code completion"
width="50%"
/>
</div>
## Connecting to ollama.com
1. Sign in to Ollama Cloud via `ollama signin`.
2. In the Ollama model settings, add a model that Ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!

View File

@@ -1,63 +0,0 @@
---
title: Onyx
---
## Overview
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.
Onyx can be deployed for single users or large organizations.
## Install Onyx
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
<Info>
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>
## Usage with Ollama
1. Log in to your Onyx deployment (create an account first).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-login.png"
alt="Onyx Login Page"
width="75%"
/>
</div>
2. In the setup process, select `Ollama` as the LLM provider.
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-llm.png"
alt="Onyx Set Up Form"
width="75%"
/>
</div>
3. Provide your **Ollama API URL** and select your models.
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-ollama-form.png"
alt="Selecting Ollama Models"
width="75%"
/>
</div>
You can also connect Onyx Cloud using the `Ollama Cloud` tab during setup.
## Send your first query
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/onyx-query.png"
alt="Onyx Query Example"
width="75%"
/>
</div>

View File

@@ -1,5 +1,5 @@
---
title: Linux
title: "Linux"
---
## Install
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```
Start Ollama:
@@ -41,8 +40,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
```
### ARM64 install
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -113,11 +112,7 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>
## Customizing
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```
## Installing specific versions
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```

View File

@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -8,7 +8,6 @@ import (
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
stream bool
responseID string
itemID string
request openai.ResponsesRequest
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
completedAt := time.Now().Unix()
response.CompletedAt = &completedAt
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
request: req,
}
// Set headers based on streaming mode

View File

@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
}
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"math/rand"
"time"
"github.com/ollama/ollama/api"
)
@@ -266,9 +265,9 @@ type ResponsesText struct {
type ResponsesTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
Description *string `json:"description"` // nullable but required
Strict *bool `json:"strict"` // nullable but required
Parameters map[string]any `json:"parameters"` // nullable but required
Description string `json:"description,omitempty"`
Strict bool `json:"strict,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
}
type ResponsesRequest struct {
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
}
}
var description string
if t.Description != nil {
description = *t.Description
}
return api.Tool{
Type: t.Type,
Function: api.ToolFunction{
Name: t.Name,
Description: description,
Description: t.Description,
Parameters: params,
},
}, nil
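Editorial note on the `ResponsesTool` change above: switching from nullable pointer fields to plain values with `omitempty` also changes the wire format, since empty fields are now omitted instead of being serialized as `null`. A minimal sketch with trimmed stand-in structs (names are illustrative, not from this file):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// toolOld mirrors the previous shape: a nullable pointer that always appears in the JSON.
type toolOld struct {
	Name        string  `json:"name"`
	Description *string `json:"description"`
}

// toolNew mirrors the new shape: a plain value that is dropped when empty.
type toolNew struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
}

func main() {
	o, _ := json.Marshal(toolOld{Name: "get_weather"})
	n, _ := json.Marshal(toolNew{Name: "get_weather"})
	fmt.Println(string(o)) // {"name":"get_weather","description":null}
	fmt.Println(string(n)) // {"name":"get_weather"}
}
```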
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
// Response types for the Responses API
// ResponsesTextField represents the text output configuration in the response.
type ResponsesTextField struct {
Format ResponsesTextFormat `json:"format"`
}
// ResponsesReasoningOutput represents reasoning configuration in the response.
type ResponsesReasoningOutput struct {
Effort *string `json:"effort,omitempty"`
Summary *string `json:"summary,omitempty"`
}
// ResponsesError represents an error in the response.
type ResponsesError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// ResponsesIncompleteDetails represents details about why a response was incomplete.
type ResponsesIncompleteDetails struct {
Reason string `json:"reason"`
}
type ResponsesResponse struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
CompletedAt *int64 `json:"completed_at"`
Status string `json:"status"`
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
Model string `json:"model"`
PreviousResponseID *string `json:"previous_response_id"`
Instructions *string `json:"instructions"`
Output []ResponsesOutputItem `json:"output"`
Error *ResponsesError `json:"error"`
Tools []ResponsesTool `json:"tools"`
ToolChoice any `json:"tool_choice"`
Truncation string `json:"truncation"`
ParallelToolCalls bool `json:"parallel_tool_calls"`
Text ResponsesTextField `json:"text"`
TopP float64 `json:"top_p"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
TopLogprobs int `json:"top_logprobs"`
Temperature float64 `json:"temperature"`
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
Usage *ResponsesUsage `json:"usage"`
MaxOutputTokens *int `json:"max_output_tokens"`
MaxToolCalls *int `json:"max_tool_calls"`
Store bool `json:"store"`
Background bool `json:"background"`
ServiceTier string `json:"service_tier"`
Metadata map[string]any `json:"metadata"`
SafetyIdentifier *string `json:"safety_identifier"`
PromptCacheKey *string `json:"prompt_cache_key"`
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
Status string `json:"status"`
Model string `json:"model"`
Output []ResponsesOutputItem `json:"output"`
Usage *ResponsesUsage `json:"usage,omitempty"`
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
// requires additional plumbing to find the effective values since the
// defaults can come from the model or the request
}
type ResponsesOutputItem struct {
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
}
type ResponsesOutputContent struct {
Type string `json:"type"` // "output_text"
Text string `json:"text"`
Annotations []any `json:"annotations"`
Logprobs []any `json:"logprobs"`
}
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens"`
}
type ResponsesOutputTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
Type string `json:"type"` // "output_text"
Text string `json:"text"`
}
type ResponsesUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
func derefFloat64(p *float64, def float64) float64 {
if p != nil {
return *p
}
return def
}
// ToResponse converts an api.ChatResponse to a Responses API response.
// The request is used to echo back request parameters in the response.
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
// ToResponse converts an api.ChatResponse to a Responses API response
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
var output []ResponsesOutputItem
// Add reasoning item if thinking is present
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
output = append(output, ResponsesOutputItem{
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
Type: "function_call",
Status: "completed",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
Role: "assistant",
Content: []ResponsesOutputContent{
{
Type: "output_text",
Text: chatResponse.Message.Content,
Annotations: []any{},
Logprobs: []any{},
Type: "output_text",
Text: chatResponse.Message.Content,
},
},
})
}
var instructions *string
if request.Instructions != "" {
instructions = &request.Instructions
}
// Build truncation with default
truncation := "disabled"
if request.Truncation != nil {
truncation = *request.Truncation
}
tools := request.Tools
if tools == nil {
tools = []ResponsesTool{}
}
text := ResponsesTextField{
Format: ResponsesTextFormat{Type: "text"},
}
if request.Text != nil && request.Text.Format != nil {
text.Format = *request.Text.Format
}
// Build reasoning output from request
var reasoning *ResponsesReasoningOutput
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
reasoning = &ResponsesReasoningOutput{}
if request.Reasoning.Effort != "" {
reasoning.Effort = &request.Reasoning.Effort
}
if request.Reasoning.Summary != "" {
reasoning.Summary = &request.Reasoning.Summary
}
}
return ResponsesResponse{
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
CompletedAt: nil, // Set by middleware when writing final response
Status: "completed",
IncompleteDetails: nil, // Only populated if response incomplete
Model: model,
PreviousResponseID: nil, // Not supported
Instructions: instructions,
Output: output,
Error: nil, // Only populated on failure
Tools: tools,
ToolChoice: "auto", // Default value
Truncation: truncation,
ParallelToolCalls: true, // Default value
Text: text,
TopP: derefFloat64(request.TopP, 1.0),
PresencePenalty: 0, // Default value
FrequencyPenalty: 0, // Default value
TopLogprobs: 0, // Default value
Temperature: derefFloat64(request.Temperature, 1.0),
Reasoning: reasoning,
ID: responseID,
Object: "response",
CreatedAt: chatResponse.CreatedAt.Unix(),
Status: "completed",
Model: model,
Output: output,
Usage: &ResponsesUsage{
InputTokens: chatResponse.PromptEvalCount,
OutputTokens: chatResponse.EvalCount,
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
// TODO(drifkin): wire through the actual values
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
// TODO(drifkin): wire through the actual values
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
},
MaxOutputTokens: request.MaxOutputTokens,
MaxToolCalls: nil, // Not supported
Store: false, // We don't store responses
Background: request.Background,
ServiceTier: "default", // Default value
Metadata: map[string]any{},
SafetyIdentifier: nil, // Not supported
PromptCacheKey: nil, // Not supported
}
}
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
responseID string
itemID string
model string
request ResponsesRequest
// State tracking (mutated across Process calls)
firstWrite bool
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
}
// NewResponsesStreamConverter creates a new converter with the given configuration.
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
return &ResponsesStreamConverter{
responseID: responseID,
itemID: itemID,
model: model,
request: request,
firstWrite: true,
}
}
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
return events
}
// buildResponseObject creates a full response object with all required fields for streaming events.
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
var instructions any = nil
if c.request.Instructions != "" {
instructions = c.request.Instructions
}
truncation := "disabled"
if c.request.Truncation != nil {
truncation = *c.request.Truncation
}
var tools []any
if c.request.Tools != nil {
for _, t := range c.request.Tools {
tools = append(tools, map[string]any{
"type": t.Type,
"name": t.Name,
"description": t.Description,
"strict": t.Strict,
"parameters": t.Parameters,
})
}
}
if tools == nil {
tools = []any{}
}
textFormat := map[string]any{"type": "text"}
if c.request.Text != nil && c.request.Text.Format != nil {
textFormat = map[string]any{
"type": c.request.Text.Format.Type,
}
if c.request.Text.Format.Name != "" {
textFormat["name"] = c.request.Text.Format.Name
}
if c.request.Text.Format.Schema != nil {
textFormat["schema"] = c.request.Text.Format.Schema
}
if c.request.Text.Format.Strict != nil {
textFormat["strict"] = *c.request.Text.Format.Strict
}
}
var reasoning any = nil
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
r := map[string]any{}
if c.request.Reasoning.Effort != "" {
r["effort"] = c.request.Reasoning.Effort
} else {
r["effort"] = nil
}
if c.request.Reasoning.Summary != "" {
r["summary"] = c.request.Reasoning.Summary
} else {
r["summary"] = nil
}
reasoning = r
}
// Build top_p and temperature with defaults
topP := 1.0
if c.request.TopP != nil {
topP = *c.request.TopP
}
temperature := 1.0
if c.request.Temperature != nil {
temperature = *c.request.Temperature
}
return map[string]any{
"id": c.responseID,
"object": "response",
"created_at": time.Now().Unix(),
"completed_at": nil,
"status": status,
"incomplete_details": nil,
"model": c.model,
"previous_response_id": nil,
"instructions": instructions,
"output": output,
"error": nil,
"tools": tools,
"tool_choice": "auto",
"truncation": truncation,
"parallel_tool_calls": true,
"text": map[string]any{"format": textFormat},
"top_p": topP,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_logprobs": 0,
"temperature": temperature,
"reasoning": reasoning,
"usage": usage,
"max_output_tokens": c.request.MaxOutputTokens,
"max_tool_calls": nil,
"store": false,
"background": c.request.Background,
"service_tier": "default",
"metadata": map[string]any{},
"safety_identifier": nil,
"prompt_cache_key": nil,
}
}
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
return c.newEvent("response.created", map[string]any{
"response": c.buildResponseObject("in_progress", []any{}, nil),
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
})
}
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
return c.newEvent("response.in_progress", map[string]any{
"response": c.buildResponseObject("in_progress", []any{}, nil),
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "in_progress",
"output": []any{},
},
})
}
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
// Emit delta
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"delta": thinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"delta": thinking,
}))
// TODO(drifkin): consider adding
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
events := []ResponsesStreamEvent{
c.newEvent("response.reasoning_summary_text.done", map[string]any{
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"summary_index": 0,
"text": c.accumulatedThinking,
"item_id": c.reasoningItemID,
"output_index": c.outputIndex,
"text": c.accumulatedThinking,
}),
c.newEvent("response.output_item.done", map[string]any{
"output_index": c.outputIndex,
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": c.contentIndex,
"part": map[string]any{
"type": "output_text",
"text": "",
"annotations": []any{},
"logprobs": []any{},
"type": "output_text",
"text": "",
},
}))
}
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
"output_index": c.outputIndex,
"content_index": 0,
"delta": content,
"logprobs": []any{},
}))
return events
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
"type": "output_text",
"text": c.accumulatedText,
}},
})
}
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"text": c.accumulatedText,
"logprobs": []any{},
}))
// response.content_part.done
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"output_index": c.outputIndex,
"content_index": 0,
"part": map[string]any{
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
"type": "output_text",
"text": c.accumulatedText,
},
}))
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
"status": "completed",
"role": "assistant",
"content": []map[string]any{{
"type": "output_text",
"text": c.accumulatedText,
"annotations": []any{},
"logprobs": []any{},
"type": "output_text",
"text": c.accumulatedText,
}},
},
}))
}
// response.completed
usage := map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
"output_tokens_details": map[string]any{
"reasoning_tokens": 0,
},
}
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
response["completed_at"] = time.Now().Unix()
events = append(events, c.newEvent("response.completed", map[string]any{
"response": response,
"response": map[string]any{
"id": c.responseID,
"object": "response",
"status": "completed",
"output": c.buildFinalOutput(),
"usage": map[string]any{
"input_tokens": r.PromptEvalCount,
"output_tokens": r.EvalCount,
"total_tokens": r.PromptEvalCount + r.EvalCount,
},
},
}))
return events

View File

@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
}
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with content
events := converter.Process(api.ChatResponse{
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
}
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
}
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk with thinking
events := converter.Process(api.ChatResponse{
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
Content: "The answer is 42",
},
Done: true,
}, ResponsesRequest{})
})
// Should have 2 output items: reasoning + message
if len(response.Output) != 2 {
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
// Verify that response.output_item.done includes content field for messages
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// First chunk
converter.Process(api.ChatResponse{
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
// Verify that response.completed includes the output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
// Process some content
converter.Process(api.ChatResponse{
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
// Verify that response.created includes an empty output array
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hi"},
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
// Verify that events include incrementing sequence numbers
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{Content: "Hello"},
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
// Verify that function call items include status field
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
events := converter.Process(api.ChatResponse{
Message: api.Message{

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -37,11 +36,10 @@ type Terminal struct {
}
type Instance struct {
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
pastedLines []string
Prompt *Prompt
Terminal *Terminal
History *History
Pasting bool
}
func New(prompt Prompt) (*Instance, error) {
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
case CharEsc:
esc = true
case CharInterrupt:
i.pastedLines = nil
i.Prompt.UseAlt = false
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
case CharForward:
buf.MoveRight()
case CharBackspace, CharCtrlH:
if buf.IsEmpty() && len(i.pastedLines) > 0 {
lastIdx := len(i.pastedLines) - 1
prevLine := i.pastedLines[lastIdx]
i.pastedLines = i.pastedLines[:lastIdx]
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
if len(i.pastedLines) == 0 {
fmt.Print(i.Prompt.Prompt)
i.Prompt.UseAlt = false
} else {
fmt.Print(i.Prompt.AltPrompt)
}
for _, r := range prevLine {
buf.Add(r)
}
} else {
buf.Remove()
}
buf.Remove()
case CharTab:
// todo: convert back to real tabs
for range 8 {
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharCtrlJ:
i.pastedLines = append(i.pastedLines, buf.String())
buf.Buf.Clear()
buf.Pos = 0
buf.DisplayPos = 0
buf.LineHasSpace.Clear()
fmt.Println()
fmt.Print(i.Prompt.AltPrompt)
i.Prompt.UseAlt = true
continue
case CharEnter:
case CharEnter, CharCtrlJ:
output := buf.String()
if len(i.pastedLines) > 0 {
output = strings.Join(i.pastedLines, "\n") + "\n" + output
i.pastedLines = nil
}
if output != "" {
i.History.Add(output)
}
buf.MoveToEnd()
fmt.Println()
i.Prompt.UseAlt = false
return output, nil
default:

View File

@@ -179,7 +179,7 @@ _build_macapp() {
fi
rm -f dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg

View File

@@ -25,6 +25,14 @@ import (
"github.com/ollama/ollama/x/tools"
)
// MultilineState tracks the state of multiline input
type MultilineState int
const (
MultilineNone MultilineState = iota
MultilineSystem
)
// Tool output capping constants
const (
// localModelTokenLimit is the token limit for local models (smaller context).
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: "Press Enter to send",
AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
var sb strings.Builder
var format string
var system string
var multiline MultilineState = MultilineNone
for {
line, err := scanner.Readline()
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
scanner.Prompt.UseAlt = false
sb.Reset()
multiline = MultilineNone
continue
case err != nil:
return err
}
switch {
case multiline != MultilineNone:
// check if there's a multiline terminating string
before, ok := strings.CutSuffix(line, `"""`)
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
continue
}
switch multiline {
case MultilineSystem:
system = sb.String()
newMessage := api.Message{Role: "system", Content: system}
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
}
multiline = MultilineNone
scanner.Prompt.UseAlt = false
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
options[args[2]] = fp[args[2]]
case "system":
if len(args) < 3 {
fmt.Println("Usage: /set system <message>")
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
continue
}
system = strings.Join(args[2:], " ")
newMessage := api.Message{Role: "system", Content: system}
multiline = MultilineSystem
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
if !ok {
multiline = MultilineNone
} else {
// only cut suffix if the line is multiline
line, ok = strings.CutSuffix(line, `"""`)
if ok {
multiline = MultilineNone
}
}
sb.WriteString(line)
if multiline != MultilineNone {
scanner.Prompt.UseAlt = true
continue
}
system = sb.String()
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
// Replace the last message
messages[len(messages)-1] = newMessage
} else {
messages = append(messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
continue
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
sb.WriteString(line)
}
if sb.Len() > 0 {
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
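Editorial sketch of the multi-line behavior introduced by the hunks above: input that opens with `"""` keeps accumulating lines until a closing `"""` is seen, using `strings.CutPrefix` and `strings.CutSuffix` as in the diff. The function below is illustrative only; its name and shape are not from the repository.

```go
package main

import (
	"fmt"
	"strings"
)

// collectMultiline mimics the `"""` handling: a leading `"""` starts a block
// that ends at a trailing `"""`; anything else is treated as single-line input.
func collectMultiline(lines []string) (string, bool) {
	var sb strings.Builder
	multiline := false
	for _, line := range lines {
		if !multiline {
			rest, ok := strings.CutPrefix(line, `"""`)
			if !ok {
				return line, true // plain single-line input
			}
			if body, closed := strings.CutSuffix(rest, `"""`); closed {
				return body, true // opened and closed on the same line
			}
			sb.WriteString(rest + "\n")
			multiline = true
			continue
		}
		if body, closed := strings.CutSuffix(line, `"""`); closed {
			return sb.String() + body, true
		}
		sb.WriteString(line + "\n")
	}
	return sb.String(), false // input ended before the closing """
}

func main() {
	out, done := collectMultiline([]string{`"""You are concise.`, "Answer briefly.", `Be kind."""`})
	fmt.Println(done) // true
	fmt.Println(out)  // the three lines joined with newlines
}
```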

View File

@@ -11,9 +11,11 @@ import (
"os"
"path/filepath"
"runtime/pprof"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -61,6 +63,7 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
@@ -117,6 +120,33 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *glmImageFlag:
m := &glm_image.Model{}
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
var loadErr error
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
loadErr = m.LoadFromPath(*modelPath)
} else {
loadErr = m.Load(*modelPath)
}
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
Temperature: float32(*temperature),
TopP: float32(*topP),
GuidanceScale: float32(*cfgScale),
MaxVisualTokens: int32(*maxTokens),
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {

View File

@@ -48,7 +48,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae"}
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
@@ -126,10 +126,13 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"vision_language_encoder/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"processor/tokenizer.json", // GLM-Image main tokenizer
"processor/tokenizer_config.json", // GLM-Image tokenizer config
}
for _, cfgPath := range configFiles {

x/imagegen/imagegen.md (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
# Image generation models (experimental)
Experimental image generation models are available for **macOS** in Ollama:
## Available models
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
```
ollama run x/z-image-turbo
```
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models.
More models coming soon:
1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

View File

@@ -27,6 +27,7 @@ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}
// CheckPlatformSupport validates that image generation is supported on the current platform.

View File

@@ -0,0 +1,693 @@
//go:build mlx
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image
import (
"context"
"fmt"
"math"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
PadTokenID int32
EOSTokenID int32
UNKTokenID int32
}
// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
return &ByT5Tokenizer{
PadTokenID: 0,
EOSTokenID: 1,
UNKTokenID: 2,
}
}
// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
bytes := []byte(text)
tokens := make([]int32, len(bytes))
for i, b := range bytes {
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
// (tokens 0, 1, 2 are PAD, EOS, UNK)
tokens[i] = int32(b) + 3
}
return tokens
}
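// Editorial note (not part of the original file): with this mapping,
// Encode("Hi") yields [75, 108], since 'H' (72) and 'i' (105) are each
// shifted past the three special tokens; Decode reverses the shift.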
// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
bytes := make([]byte, 0, len(tokens))
for _, tok := range tokens {
if tok >= 3 && tok < 259 {
bytes = append(bytes, byte(tok-3))
}
}
return string(bytes)
}
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
GuidanceScale float32 // Guidance scale (default: 1.5)
Width int32 // Image width (default: 1024, must be divisible by 32)
Height int32 // Image height (default: 1024, must be divisible by 32)
Steps int // Diffusion denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// AR generation options
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
Temperature float32 // AR sampling temperature (default: 0.9)
TopP float32 // Nucleus sampling (default: 0.75)
}
// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
// Model represents a GLM-Image hybrid model.
type Model struct {
ModelName string
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
TextEncoder *T5TextEncoder
VisionLanguageEncoder *VisionLanguageEncoder
Transformer *DiffusionTransformer
VAEDecoder *VAEDecoder
}
// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
// Used for T5 text encoder (glyph embeddings)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizer(manifest)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder (~830MB)
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder (~19GB, 9B params)
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer (~13GB, 7B params)
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder (~775MB)
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelPath
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizerFromPath(modelPath)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.5
}
// Calculate MaxVisualTokens based on image dimensions
// GLM-Image generates TWO grids of visual tokens:
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
// 2. Then: target (large) grid - tokenH × tokenW tokens
// After generation, we extract only the TARGET grid tokens for diffusion.
factor := int32(32)
tokenH := cfg.Height / factor
tokenW := cfg.Width / factor
targetGridTokens := tokenH * tokenW
// Compute prev grid dimensions using diffusers formula:
// ratio = token_h / token_w
// prev_token_h = int(sqrt(ratio) * 16)
// prev_token_w = int(sqrt(1/ratio) * 16)
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
prevGridTokens := prevTokenH * prevTokenW
// Total tokens to generate = prev grid + target grid
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
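// Editorial worked example (not part of the original file): for a 1024x1024
// image, tokenH = tokenW = 1024/32 = 32, so the target grid has 1024 tokens;
// ratio = 1 gives prevTokenH = prevTokenW = 16, a prev grid of 256 tokens,
// and MaxVisualTokens = 256 + 1024 = 1280.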
if cfg.Temperature <= 0 {
cfg.Temperature = 0.9
}
if cfg.TopP <= 0 {
cfg.TopP = 0.75
}
// Ensure dimensions are divisible by 32
cfg.Width = (cfg.Width / 32) * 32
cfg.Height = (cfg.Height / 32) * 32
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
// Progress callback helper
progress := func(stage string, step, total int) {
if cfg.Progress != nil {
cfg.Progress(stage, step, total)
}
}
// === PHASE 1: T5 Text Encoding ===
fmt.Println("[T5] Encoding glyph text...")
progress("text_encoding", 0, 1)
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
mlx.Keep(textEmbed)
mlx.Eval(textEmbed)
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
progress("text_encoding", 1, 1)
// === PHASE 2: AR Visual Token Generation ===
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
visualTokens := m.VisionLanguageEncoder.Generate(
cfg.Prompt,
m.GLMTokenizer,
cfg.MaxVisualTokens,
cfg.Temperature,
cfg.TopP,
cfg.Seed,
cfg.Height,
cfg.Width,
func(step int) {
if step%100 == 0 || step < 10 {
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
}
progress("ar_generation", step, int(cfg.MaxVisualTokens))
},
)
mlx.Keep(visualTokens)
mlx.Eval(visualTokens)
fmt.Printf("[AR] Done generating visual tokens\n")
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
vtShape := visualTokens.Shape()
totalGenerated := vtShape[1]
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
// Extract only the TARGET grid tokens (skip the prev grid tokens)
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
// large_image_start_offset = prev_grid_size
var targetGridVisualTokens *mlx.Array
if totalGenerated >= prevGridTokens+targetGridTokens {
// Full generation completed - extract target grid
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, prevGridTokens + targetGridTokens})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
} else if totalGenerated > prevGridTokens {
// Partial target grid - take what we have
actualTargetTokens := totalGenerated - prevGridTokens
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, totalGenerated})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
actualTargetTokens, targetGridTokens)
} else {
// Not enough tokens - EOS came too early
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
totalGenerated, prevGridTokens)
}
// === PHASE 3: Diffusion Decoding ===
// Setup scheduler with dynamic shift based on image size
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
// Initialize noise latents [B, C, H, W]
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
// target_grid tokens -> 2x upsample -> patch_count
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
mlx.Keep(priorEmbed)
mlx.Eval(priorEmbed)
// Prepare text conditioning (project T5 embeddings)
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
mlx.Keep(textCond)
mlx.Eval(textCond)
// === CFG Setup ===
// For classifier-free guidance, we need unconditional (negative) text embeddings
// GLM-Image uses empty string "" for negative prompt
doCFG := cfg.GuidanceScale > 1.0
var negativeTextCond *mlx.Array
if doCFG {
// Encode empty string for negative prompt
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
mlx.Keep(negativeTextEmbed)
mlx.Eval(negativeTextEmbed)
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
mlx.Keep(negativeTextCond)
mlx.Eval(negativeTextCond)
negativeTextEmbed.Free()
}
// Prepare conditioning inputs
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
targetSize = mlx.ToBFloat16(targetSize)
cropCoords = mlx.ToBFloat16(cropCoords)
mlx.Keep(targetSize)
mlx.Keep(cropCoords)
mlx.Eval(targetSize, cropCoords)
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
// Denoising loop
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
progress("diffusion", 0, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
latents.Free()
return nil, ctx.Err()
default:
}
}
// Get timestep value for the transformer
// scheduler.Timesteps contains raw timestep values (from 1000 down to 1)
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
timestepVal := scheduler.Timesteps[i] - 1
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
patches := PatchifyLatents(latents, tcfg.PatchSize)
// Transformer forward with MMDiT architecture
// Conditional pass (with text + prior embeddings)
outputCond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed,
textCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
false, // priorTokenDrop = false for conditional
)
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
var noisePred *mlx.Array
if doCFG {
// Unconditional pass (empty text, dropped prior embeddings)
outputUncond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
negativeTextCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
true, // priorTokenDrop = true for unconditional
)
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
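// With guidance_scale = 1.0 this reduces to the conditional prediction alone;
// larger values push the result further from the unconditional baseline.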
diff := mlx.Sub(noisePredCond, noisePredUncond)
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
noisePred = mlx.Add(noisePredUncond, scaled)
} else {
noisePred = noisePredCond
}
// Scheduler step
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
progress("diffusion", i+1, cfg.Steps)
}
// Cleanup intermediate arrays
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
if negativeTextCond != nil {
negativeTextCond.Free()
}
targetSize.Free()
cropCoords.Free()
// === PHASE 4: VAE Decode ===
progress("vae_decode", 0, 1)
decoded := m.VAEDecoder.Decode(latents)
mlx.Eval(decoded)
latents.Free()
progress("vae_decode", 1, 1)
return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
// Each token at (i, j) becomes scale*scale tokens in the output
mlx.Eval(tokens)
tokenData := tokens.DataInt32()
numTokens := int32(len(tokenData))
expectedTokens := prevH * prevW
// Warn if we got fewer tokens than expected (early EOS)
if numTokens < expectedTokens {
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
numTokens, expectedTokens)
}
targetH := prevH * scale
targetW := prevW * scale
upsampled := make([]int32, targetH*targetW)
for i := int32(0); i < prevH; i++ {
for j := int32(0); j < prevW; j++ {
srcIdx := i*prevW + j
// Handle early EOS: use 0 (padding) for missing tokens
var val int32
if srcIdx < numTokens {
val = tokenData[srcIdx]
} else {
val = 0 // Padding token
}
// Place in scale*scale positions
dstI := i * scale
dstJ := j * scale
for di := int32(0); di < scale; di++ {
for dj := int32(0); dj < scale; dj++ {
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
}
}
}
}
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
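// Example (illustrative values: patch size p = 2, C = 16 latent channels):
// a [1, 16, 128, 128] latent gives pH = pW = 64, i.e. a [1, 4096, 64] output.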
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// Transpose: -> [B, pH, pW, C, p, p]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// Flatten: -> [B, pH*pW, C*p*p]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// Transpose: -> [B, C, pH, p, pW, p]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// Reshape: -> [B, C, H, W]
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
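// Worked example with the default scheduler config (base_shift 0.25, max_shift 0.75,
// base_image_seq_len 256): for imgSeqLen = 4096, m = sqrt(4096/256) = 4, so the
// shift is 4*0.75 + 0.25 = 3.25.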
func CalculateShift(imgSeqLen int32) float32 {
cfg := DefaultSchedulerConfig()
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
return m*cfg.MaxShift + cfg.BaseShift
}
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
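// For example, a 16x16 grid of 256 token IDs becomes a 32x32 grid of 1024 IDs,
// with each source ID repeated across a 2x2 block.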
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
shape := tokens.Shape()
B := shape[0]
// Reshape to [B, 1, H, W] for interpolation
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
// Convert to float for interpolation
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
// 2x nearest neighbor upsample
// [B, 1, H, W] -> [B, 1, H*2, W*2]
upsampled := nearestUpsample2x(tokensFloat)
// Convert back to int and reshape to [B, H*2*W*2]
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Repeat each element 2x2
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Tile to repeat each pixel 2x2
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
// Reshape to final size
return mlx.Reshape(x, B, C, H*2, W*2)
}

View File

@@ -0,0 +1,358 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen"
)
// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
Vocab map[string]int32 // token string -> token ID
VocabReverse map[int32]string // token ID -> token string
SpecialTokens map[string]int32 // special token strings -> IDs
// Special token IDs
SopTokenID int32 // <sop> = grid_bos_token (167845)
EopTokenID int32 // <eop> = grid_eos_token (167846)
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
PadTokenID int32
// Sorted vocab keys by length (longest first) for greedy matching
sortedTokens []string
}
// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
Model struct {
Vocab map[string]int32 `json:"vocab"`
} `json:"model"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory in manifest
data, err := manifest.ReadConfig("processor/tokenizer.json")
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
data, err := os.ReadFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
if text == "" {
return []int32{}
}
var tokens []int32
// First, check for and handle special tokens
// Replace special tokens with placeholders, encode, then restore
specialReplacements := make(map[string]int32)
for special, id := range t.SpecialTokens {
if strings.Contains(text, special) {
specialReplacements[special] = id
}
}
// Process text character by character with special token handling
i := 0
isFirstToken := true
for i < len(text) {
// Check for special tokens first
foundSpecial := false
for special, id := range specialReplacements {
if strings.HasPrefix(text[i:], special) {
tokens = append(tokens, id)
i += len(special)
isFirstToken = false
foundSpecial = true
break
}
}
if foundSpecial {
continue
}
// Handle regular text with GPT-2 style space prefix
// "Ġ" (U+0120) represents a space before a token
remaining := text[i:]
// Try to find the longest matching token
matched := false
for _, token := range t.sortedTokens {
// Skip special tokens in regular matching
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
continue
}
// Check if this token matches
tokenText := token
// Handle the Ġ prefix (represents space)
if strings.HasPrefix(token, "Ġ") {
// This token expects a leading space
if i > 0 || !isFirstToken {
// Check if remaining starts with space + token content
tokenContent := token[len("Ġ"):]
if strings.HasPrefix(remaining, " "+tokenContent) {
tokens = append(tokens, t.Vocab[token])
i += 1 + len(tokenContent) // space + content
isFirstToken = false
matched = true
break
}
}
} else {
// Regular token without space prefix
if strings.HasPrefix(remaining, tokenText) {
tokens = append(tokens, t.Vocab[token])
i += len(tokenText)
isFirstToken = false
matched = true
break
}
}
}
if !matched {
// No token found - skip this character (or use UNK)
// For now, just skip unknown characters
i++
}
}
return tokens
}
// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
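//
// Example (1024x1024 target): tokenH = tokenW = 1024/32 = 32 and
// prevTokenH = prevTokenW = 16, so the encoded suffix is:
//
//	<sop>32 32<eop><sop>16 16<eop><|dit_token_16384|>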
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
// Calculate grid dimensions
factor := int32(32)
height := (targetHeight / factor) * factor
width := (targetWidth / factor) * factor
tokenH := height / factor
tokenW := width / factor
// Calculate previous grid dimensions
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(sqrt(ratio) * 16)
prevTokenW := int32(sqrt(1.0/ratio) * 16)
// Encode the prompt text
promptTokens := t.Encode(prompt)
// Build the full sequence:
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
var tokens []int32
tokens = append(tokens, promptTokens...)
// First grid: <sop> H W <eop>
// First number has no space prefix, second number has space prefix (Ġ)
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(tokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
tokens = append(tokens, t.EopTokenID)
// Second grid: <sop> prevH prevW <eop>
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
tokens = append(tokens, t.EopTokenID)
// BOS token (start of image generation)
tokens = append(tokens, t.BosTokenID)
return tokens
}
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up the whole number as a single token
if id, ok := t.Vocab[s]; ok {
return []int32{id}
}
// Fallback: encode digit by digit
var tokens []int32
for _, c := range s {
if id, ok := t.Vocab[string(c)]; ok {
tokens = append(tokens, id)
}
}
return tokens
}
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
spaceToken := "Ġ" + s
if id, ok := t.Vocab[spaceToken]; ok {
return []int32{id}
}
// Fallback: bare space Ġ + number tokens
var tokens []int32
if spaceID, ok := t.Vocab["Ġ"]; ok {
tokens = append(tokens, spaceID)
}
tokens = append(tokens, t.encodeNumber(n)...)
return tokens
}
// sqrt computes the square root of x using Newton's method
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
// Newton's method
z := x
for i := 0; i < 10; i++ {
z = z - (z*z-x)/(2*z)
}
return z
}
// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
var sb strings.Builder
for _, id := range tokens {
if token, ok := t.VocabReverse[id]; ok {
// Handle Ġ prefix (convert back to space)
if strings.HasPrefix(token, "Ġ") {
sb.WriteString(" ")
sb.WriteString(token[len("Ġ"):])
} else {
sb.WriteString(token)
}
}
}
return sb.String()
}

View File

@@ -0,0 +1,159 @@
//go:build mlx
package glm_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.25
MaxShift float32 `json:"max_shift"` // 0.75
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "linear"
}
// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.25,
MaxShift: 0.75,
BaseImageSeqLen: 256,
MaxImageSeqLen: 4096,
UseDynamicShifting: true,
TimeShiftType: "linear",
}
}
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
Sigmas []float32 // Shifted sigmas for Euler step calculation
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{Config: cfg}
}
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
s.NumSteps = numSteps
// Calculate shift (mu) based on image sequence length
mu := s.calculateShift(imgSeqLen)
// Create timesteps: linspace from sigma_max_t to sigma_min_t
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
// Then apply time shift and append terminal sigma=0
s.Timesteps = make([]float32, numSteps)
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
for i := 0; i < numSteps; i++ {
// linspace from 1000 down to 1 (sigma_min * num_train_timesteps)
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
s.Timesteps[i] = tRaw
// Convert to sigma [0, 1]
sigma := tRaw / numTrainTimesteps
// Apply time shift if enabled
if s.Config.UseDynamicShifting && mu > 0 {
sigma = s.applyShift(mu, sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal sigma = 0 (the final clean image)
s.Sigmas[numSteps] = 0
}
// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
// m = (image_seq_len / base_seq_len) ** 0.5
// mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
cfg := s.Config
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
mu := m*cfg.MaxShift + cfg.BaseShift
return mu
}
// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
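// Worked example (linear shift): with mu = 3.25 and t = 0.5, 1/t - 1 = 1,
// so the shifted sigma is 3.25 / (3.25 + 1) ≈ 0.765.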
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if t >= 1 {
return 1
}
// The exponent (named sigma in diffusers) is 1.0 for both shift types
sigma := float32(1.0)
if s.Config.TimeShiftType == "linear" {
// Linear: mu / (mu + (1/t - 1)^sigma)
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
sigma := s.Sigmas[stepIdx]
sigmaNext := s.Sigmas[stepIdx+1]
// Euler step: x_{t-dt} = x_t + dt * v_t
dt := sigmaNext - sigma // Negative (going from noise to clean)
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
// Use sigmas (shifted) for the interpolation
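// At sigma = 1 the result is pure noise; at sigma = 0 it is the clean sample unchanged.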
sigma := s.Sigmas[timestepIdx]
oneMinusSigma := 1.0 - sigma
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
scaledNoise := mlx.MulScalar(noise, sigma)
return mlx.Add(scaledClean, scaledNoise)
}
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
return s.Timesteps
}

View File

@@ -0,0 +1,497 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// T5Config holds T5 encoder configuration
type T5Config struct {
DModel int32 `json:"d_model"` // 1472
DFF int32 `json:"d_ff"` // 3584
DKV int32 `json:"d_kv"` // 64
NumHeads int32 `json:"num_heads"` // 6
NumLayers int32 `json:"num_layers"` // 12
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
// Relative position bias
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}
// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
Config *T5Config
// Embedding (shared for ByT5)
SharedEmbed *nn.Embedding `weight:"shared"`
// Encoder layers
Layers []*T5Block `weight:"encoder.block"`
// Final layer norm
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
// Relative position bias (from first layer, shared across all)
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}
// T5Block is a single T5 encoder block
type T5Block struct {
// Self attention
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
// FFN
Layer1 *T5LayerFF `weight:"layer.1"`
}
// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
SelfAttention *T5Attention `weight:"SelfAttention"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5Attention implements T5's relative attention
type T5Attention struct {
Q *mlx.Array `weight:"q.weight"` // No bias in T5
K *mlx.Array `weight:"k.weight"`
V *mlx.Array `weight:"v.weight"`
O *mlx.Array `weight:"o.weight"`
NHeads int32
DKV int32
Scale float32
}
// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
Wo *mlx.Array `weight:"wo.weight"` // down projection
}
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32
}
// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
func (m *T5TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.LayerNormEps
for _, block := range m.Layers {
attn := block.Layer0.SelfAttention
attn.NHeads = cfg.NumHeads
attn.DKV = cfg.DKV
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
}
}
// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
cfg := m.Config
// Get embeddings
h := m.SharedEmbed.Forward(tokens)
// Compute relative position bias once
seqLen := tokens.Shape()[1]
posBias := m.computeRelativePositionBias(seqLen)
// Forward through layers
for _, block := range m.Layers {
h = block.Forward(h, posBias, cfg.LayerNormEps)
}
// Final norm
h = m.FinalNorm.Forward(h)
return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
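// For example, the prompt `a poster that says "GRAND OPENING"` yields
// glyphTexts = ["GRAND OPENING"].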
func extractGlyphTexts(prompt string) []string {
var glyphTexts []string
// Extract text in single quotes: 'text'
re1 := regexp.MustCompile(`'([^']*)'`)
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Unicode curly double quotes: “text”
re2 := regexp.MustCompile(`“([^“”]*)”`)
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in ASCII double quotes: "text"
re3 := regexp.MustCompile(`"([^"]*)"`)
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Japanese quotes: 「text」
re4 := regexp.MustCompile(`「([^「」]*)」`)
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
return glyphTexts
}
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
// Extract glyph texts from prompt (text in quotes)
glyphTexts := extractGlyphTexts(prompt)
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
if len(glyphTexts) == 0 {
glyphTexts = []string{""}
}
// Encode each glyph text and collect token sequences
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
var allTokenSeqs [][]int32
for _, glyphText := range glyphTexts {
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
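// For illustration, the glyph text "Hi" maps to [75, 108] ('H' = byte 72 -> 75,
// 'i' = byte 105 -> 108) before the EOS token (1) is appended below.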
tokens := tok.Encode(glyphText)
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
tokens = append(tokens, tok.EOSTokenID)
allTokenSeqs = append(allTokenSeqs, tokens)
}
// Process each glyph text through the encoder
var allEmbeddings []*mlx.Array
for _, tokens := range allTokenSeqs {
tokenLen := len(tokens)
if tokenLen == 0 {
continue
}
// Create token array [1, L]
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
// Forward through encoder
output := m.Forward(tokensArr)
mlx.Eval(output)
allEmbeddings = append(allEmbeddings, output)
}
// Concatenate all glyph embeddings along sequence dimension
var output *mlx.Array
if len(allEmbeddings) == 0 {
// Fallback: return single zero embedding
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
} else if len(allEmbeddings) == 1 {
output = allEmbeddings[0]
} else {
output = mlx.Concatenate(allEmbeddings, 1)
}
mlx.Eval(output)
return output
}
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
cfg := m.Config
// Create relative position matrix
// For each (query_pos, key_pos) pair, compute bucketed relative position
numBuckets := cfg.RelativeAttentionNumBuckets
maxDistance := cfg.RelativeAttentionMaxDistance
// Create position indices
contextPos := make([]int32, seqLen*seqLen)
memoryPos := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen; i++ {
for j := int32(0); j < seqLen; j++ {
contextPos[i*seqLen+j] = i
memoryPos[i*seqLen+j] = j
}
}
// Compute relative positions and bucket them
buckets := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen*seqLen; i++ {
relPos := memoryPos[i] - contextPos[i]
buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
}
// Create bucket indices array
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
// Transpose to [numHeads, seqLen, seqLen]
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
return bias
}
// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
var bucket int32 = 0
var n int32 = -relativePosition
if bidirectional {
numBuckets /= 2
if n < 0 {
bucket += numBuckets
n = -n
}
} else {
if n < 0 {
n = 0
}
}
// Half buckets are for exact positions, half are for log-spaced
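// With the defaults used here (32 buckets, max distance 128), relative distances
// up to 15 map one-to-one to buckets 0-15; longer distances share log-spaced
// buckets and are clamped to bucket 31.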
maxExact := numBuckets / 2
if n < maxExact {
bucket += n
} else {
// Log-spaced buckets
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
if bucket > numBuckets-1 {
bucket = numBuckets - 1
}
}
return bucket
}
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Self attention with residual
h := b.Layer0.Forward(x, posBias, eps)
// FFN with residual
h = b.Layer1.Forward(h, eps)
return h
}
// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// Attention
attnOut := l.SelfAttention.Forward(normed, posBias)
// Residual
return mlx.Add(x, attnOut)
}
// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Q, K, V projections (no bias)
// Weights are [out_features, in_features], so we use matmul with transpose
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
// Reshape to [B, L, nheads, d_kv]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
// Transpose to [B, nheads, L, d_kv]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Attention scores with relative position bias
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
// (no 1/sqrt(d_k) scale factor like in standard transformers)
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
scores = mlx.Add(scores, posBias)
// Softmax
attnWeights := mlx.Softmax(scores, -1)
// Attend to values
out := mlx.Matmul(attnWeights, v)
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, D]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
_ = D // Silence unused warning
return out
}
// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// FFN
ffOut := l.DenseReluDense.Forward(normed)
// Residual
return mlx.Add(x, ffOut)
}
// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
coeff := float32(0.044715)
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
gate = geluNew(gate)
// Up projection
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
// Gated output
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}
// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
variance := mlx.Mean(mlx.Square(x), -1, true)
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
return mlx.Mul(x, ln.Weight)
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,477 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"` // 3
OutChannels int32 `json:"out_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 16
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
LayersPerBlock int32 `json:"layers_per_block"` // 3
NormNumGroups int32 `json:"norm_num_groups"` // 32
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
ShiftFactor *float32 `json:"shift_factor"` // null
LatentsMean []float32 `json:"latents_mean"` // [16 values]
LatentsStd []float32 `json:"latents_std"` // [16 values]
}
// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
Config *VAEConfig
// Decoder components
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
}
// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Stride int32
Padding int32
}
// GroupNorm is group normalization
type GroupNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
}
// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
}
// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
Norm1 *GroupNorm `weight:"norm1"`
Conv1 *VAEConv2d `weight:"conv1"`
Norm2 *GroupNorm `weight:"norm2"`
Conv2 *VAEConv2d `weight:"conv2"`
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}
// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
Conv *VAEConv2d `weight:"conv"`
}
// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
// All up_blocks except the last have an upsampler
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
// All blocks except the last have an upsampler
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
func (m *VAEDecoder) initGroupNorms() {
cfg := m.Config
numGroups := cfg.NormNumGroups
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
if m.ConvNormOut != nil {
m.ConvNormOut.NumGroups = numGroups
m.ConvNormOut.Eps = eps
}
if m.MidBlock != nil {
for _, resnet := range m.MidBlock.Resnets {
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
for _, upBlock := range m.UpBlocks {
if upBlock == nil {
continue
}
for _, resnet := range upBlock.Resnets {
if resnet == nil {
continue
}
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Apply latent denormalization if mean/std are provided
// This matches diffusers GLM-Image: latents = latents * std + mean
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
latents = m.denormalizeLatents(latents)
}
// Convert from NCHW to NHWC for processing
// [B, C, H, W] -> [B, H, W, C]
x := mlx.Transpose(latents, 0, 2, 3, 1)
// Initial convolution
x = m.ConvIn.Forward(x)
// Mid block
x = m.MidBlock.Forward(x)
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
for i := 0; i < len(m.UpBlocks); i++ {
if m.UpBlocks[i] != nil {
x = m.UpBlocks[i].Forward(x)
}
}
// Final normalization and convolution
x = m.ConvNormOut.Forward(x)
x = mlx.SiLU(x)
x = m.ConvOut.Forward(x)
// Convert back to NCHW
// [B, H, W, C] -> [B, C, H, W]
x = mlx.Transpose(x, 0, 3, 1, 2)
// Clamp to valid range and convert to [0, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.AddScalar(x, 1.0)
x = mlx.DivScalar(x, 2.0)
return x
}
// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Create mean and std arrays [1, C, 1, 1] for broadcasting
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
// Denormalize: latents * std + mean
latents = mlx.Mul(latents, std)
latents = mlx.Add(latents, mean)
return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C_in] (NHWC)
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
// So we need to transpose from OIHW to OHWI
stride := c.Stride
if stride == 0 {
stride = 1
}
padding := c.Padding
if padding == 0 {
// Default to same padding for 3x3 kernels
wShape := c.Weight.Shape()
if len(wShape) >= 3 && wShape[2] == 3 {
padding = 1
}
}
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
out := mlx.Conv2d(x, weight, stride, padding)
if c.Bias != nil {
// Bias: [C_out] -> [1, 1, 1, C_out]
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
out = mlx.Add(out, bias)
}
return out
}
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
numGroups := gn.NumGroups
if numGroups == 0 {
numGroups = 32
}
groupSize := C / numGroups
// Reshape to [B, H, W, groups, groupSize]
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range mb.Resnets {
x = resnet.Forward(x)
}
return x
}
// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
// Apply resnets
for _, resnet := range ub.Resnets {
if resnet != nil {
x = resnet.Forward(x)
}
}
// Apply upsamplers
for _, upsampler := range ub.Upsamplers {
if upsampler != nil {
x = upsampler.Forward(x)
}
}
return x
}
// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
// First norm + activation + conv
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
// Second norm + activation + conv
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
// Shortcut for channel mismatch
if rb.ConvShortcut != nil {
residual = rb.ConvShortcut.Forward(residual)
}
return mlx.Add(h, residual)
}
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C]
// 2x nearest neighbor upsample
x = upsample2x(x)
// Conv
if us.Conv != nil {
x = us.Conv.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
hIndices := make([]int32, H*2)
for i := int32(0); i < H; i++ {
hIndices[i*2] = i
hIndices[i*2+1] = i
}
wIndices := make([]int32, W*2)
for i := int32(0); i < W; i++ {
wIndices[i*2] = i
wIndices[i*2+1] = i
}
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
// Take along height axis
x = mlx.Reshape(x, B*H, W, C)
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
x = mlx.Reshape(x, B, H, W*2, C)
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
x = mlx.Reshape(x, B*(W*2), H, C)
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
x = mlx.Reshape(x, B, W*2, H*2, C)
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
return x
}

View File

@@ -0,0 +1,982 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 4096
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
IntermediateSize int32 `json:"intermediate_size"` // 13696
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
VocabSize int32 `json:"vocab_size"` // 168064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
// RoPE config
RopeTheta float32 `json:"rope_theta"` // 10000
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
// Visual token config
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
ImageTokenID int32 `json:"image_token_id"` // 167855
// Computed
HeadDim int32
}
// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
Config *VisionLanguageConfig
// Embedding
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
// Transformer layers
Layers []*GLMBlock `weight:"model.language_model.layers"`
// Final norm
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
// LM Head
LMHead *mlx.Array `weight:"lm_head.weight"`
}
// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
// Pre-attention norm (GLM uses post-LN variant)
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
// Attention
SelfAttn *GLMAttention `weight:"self_attn"`
// MLP (fused gate_up)
MLP *GLMMLP `weight:"mlp"`
}
// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
QProj *mlx.Array `weight:"q_proj.weight"`
KProj *mlx.Array `weight:"k_proj.weight"`
VProj *mlx.Array `weight:"v_proj.weight"`
OProj *mlx.Array `weight:"o_proj.weight"`
// QKV have biases in GLM
QBias *mlx.Array `weight:"q_proj.bias"`
KBias *mlx.Array `weight:"k_proj.bias"`
VBias *mlx.Array `weight:"v_proj.bias"`
// Computed
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
PartialRotary float32 // Only rotate this fraction of head_dim
RopeTheta float32
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
Layers []cache.Cache
}
// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
layers := make([]cache.Cache, numLayers)
for i := range layers {
layers[i] = cache.NewKVCache()
}
return &ARCache{Layers: layers}
}
// Free releases all cached tensors
func (c *ARCache) Free() {
for _, layer := range c.Layers {
for _, arr := range layer.State() {
if arr != nil {
arr.Free()
}
}
}
}
// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
DownProj *mlx.Array `weight:"down_proj.weight"`
}
// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
return fmt.Errorf("config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &rawCfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
func (m *VisionLanguageEncoder) initComputedFields() {
cfg := m.Config
for _, block := range m.Layers {
block.SelfAttn.NHeads = cfg.NumAttentionHeads
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
block.SelfAttn.HeadDim = cfg.HeadDim
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
block.SelfAttn.RopeTheta = cfg.RopeTheta
block.SelfAttn.MRoPESection = cfg.MRoPESection
// Set norm eps
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
}
m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
prompt string,
tok *GLMTokenizer,
maxTokens int32,
temperature float32,
topP float32,
seed int64,
targetHeight, targetWidth int32,
progressFn func(int),
) *mlx.Array {
cfg := m.Config
// Encode prompt with grid tokens using GLM tokenizer
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
// Calculate grid dimensions for MRoPE position IDs
factor := int32(32)
tokenH := targetHeight / factor
tokenW := targetWidth / factor
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
prevGridSize := prevTokenH * prevTokenW
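// Illustrative sizing example (assuming a 1024x1024 target): tokenH = tokenW = 1024/32 = 32,
// ratio = 1.0, so prevTokenH = prevTokenW = 16 and prevGridSize = 256. Generation therefore
// covers a coarse 16x16 "prev" grid followed by the full 32x32 target grid.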
// Create KV cache for all layers
cache := NewARCache(cfg.NumHiddenLayers)
defer cache.Free()
// ===== PREFILL PHASE =====
// Process entire prompt at once, populate cache
promptLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
mlx.Eval(h)
// Compute position IDs for prefill (text tokens use same position for all dims)
prefillPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
prefillPositions[dim] = make([]int32, promptLen)
for i := int32(0); i < promptLen; i++ {
prefillPositions[dim][i] = i
}
}
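// Example: prompt token 5 gets position 5 in all three MRoPE streams (temporal, height, width).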
// Forward through layers (prefill)
for i, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
if i > 0 { // Don't free the embedding output on the first layer
oldH.Free()
}
}
// Eval h and cache arrays together so cache is materialized
evalArgs := []*mlx.Array{h}
for _, lc := range cache.Layers {
evalArgs = append(evalArgs, lc.State()...)
}
mlx.Eval(evalArgs...)
// Final norm and get logits for last position
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
h.Free()
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
lastH.Free()
// Sample first token
var sampleCounter int64 = 0
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
// AR generation loop with caching
// Visual tokens are stored as VQ codebook indices [0, 16383]
// The LM head outputs indices [0, 16511] where:
// - [0, 16383] are VQ codes
// - 16384 is BOS
// - 16385 is EOS
visualTokens := make([]int32, 0, maxTokens)
posOffset := promptLen
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
// Preallocate slice for old cache state to reuse
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for i := int32(0); i < maxTokens; i++ {
if progressFn != nil {
progressFn(int(i))
}
// Check for end token (EOS = 16385)
if nextToken == cfg.ImageEndTokenID {
break
}
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
if nextToken == cfg.ImageStartTokenID {
// BOS token - skip storing but continue generation
} else if nextToken < cfg.ImageStartTokenID {
// This is an actual VQ code [0, 16383] - store it
visualTokens = append(visualTokens, nextToken)
}
// Tokens >= 16386 are other special tokens, skip them
// ===== DECODE PHASE =====
// Save old cache state before forward (to free after eval)
oldCacheState = oldCacheState[:0]
for _, lc := range cache.Layers {
oldCacheState = append(oldCacheState, lc.State()...)
}
// Only process the new token, use cached K,V
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
// Compute MRoPE position IDs for this visual token
// Visual tokens are arranged in two grids: prev grid then target grid
// Position dimensions: [temporal, height, width]
decodePositions := computeVisualTokenPositions(
visualTokenIdx, posOffset, promptLen,
prevTokenH, prevTokenW, prevGridSize,
tokenH, tokenW,
)
// Forward through layers (decode with cache)
for j, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
if j > 0 { // Don't free the embedding on first layer
oldH.Free()
}
}
// Eval h and new cache state
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for _, lc := range cache.Layers {
newCacheState = append(newCacheState, lc.State()...)
}
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
// Free old cache state (now that new state is evaluated)
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
// Final norm
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
// Get logits (h is already [1, 1, hidden_size])
h = mlx.Reshape(h, 1, cfg.HiddenSize)
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
h.Free()
// Sample next token
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
posOffset++
visualTokenIdx++
// Periodically clear cache to release intermediate memory
if i%256 == 0 {
mlx.ClearCache()
}
}
if len(visualTokens) == 0 {
// Return at least one token to avoid empty tensor issues
visualTokens = append(visualTokens, 0)
}
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
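//
// Worked example (illustrative): with promptLen=32, prevH=prevW=16 (prevSize=256) and
// targetW=32, the prev grid starts at decode_pos 32 and the target grid at 32+16=48.
// The token at visualIdx=300 lands in the target grid (300-256=44, so row 1, col 12),
// giving positions [48, 49, 60] for (temporal, height, width).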
func computeVisualTokenPositions(
visualIdx int32, absPos int32, promptLen int32,
prevH, prevW, prevSize int32,
targetH, targetW int32,
) [][]int32 {
positions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
positions[dim] = make([]int32, 1)
}
// First grid (prev grid) starts at decode_pos = promptLen
prevGridDecodePos := promptLen
// Second grid (target grid) starts after first grid
// next_pos = prev_decode_pos + max(prevH, prevW)
maxPrev := prevH
if prevW > maxPrev {
maxPrev = prevW
}
targetGridDecodePos := prevGridDecodePos + maxPrev
// Compute position IDs based on which grid the token is in
if visualIdx < prevSize {
// Token is in the prev grid (prev_token_h × prev_token_w)
row := visualIdx / prevW
col := visualIdx % prevW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = prevGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = prevGridDecodePos + row
positions[2][0] = prevGridDecodePos + col
} else {
// Token is in the target grid (token_h × token_w)
targetIdx := visualIdx - prevSize
row := targetIdx / targetW
col := targetIdx % targetW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = targetGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = targetGridDecodePos + row
positions[2][0] = targetGridDecodePos + col
}
_ = targetH // Kept in the signature for documentation; not needed by the computation
_ = absPos // No longer used - kept for API compatibility
return positions
}
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: for GLM-Image, greedy decoding is not used because it can produce repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
// Output index directly corresponds to vocab ID [0, 16511]
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
visualLogits := logits
// Apply temperature
if temperature != 1.0 && temperature > 0 {
visualLogits = mlx.DivScalar(visualLogits, temperature)
}
// Apply softmax to get probabilities
probs := mlx.Softmax(visualLogits, -1)
mlx.Eval(probs)
// Get the sampled index using top-p sampling
// This directly gives us the vocab ID in [0, 16511]
// Special tokens: 16384 = BOS, 16385 = EOS
// Use seed + counter for reproducible but different random values
effectiveSeed := seed + *sampleCounter
*sampleCounter++
return sampleTopP(probs, topP, effectiveSeed)
}
// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
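//
// Worked example (illustrative): for sorted probs [0.5, 0.3, 0.15, 0.05] and topP=0.75, the
// cumulative sums are [0.5, 0.8, 0.95, 1.0]; the loop stops at index 1 (0.8 >= 0.75), so
// cutoffIdx=2 and totalProb=0.8. The renormalized truncated distribution is [0.625, 0.375],
// and a uniform draw r selects index 0 when r < 0.625, index 1 otherwise.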
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
// Negate probs for descending sort (Argsort only does ascending)
negProbs := mlx.MulScalar(probs, -1)
sortedIndices := mlx.Argsort(negProbs, -1)
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
cumProbs := mlx.Cumsum(sortedProbs, -1)
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
// Find cutoff index where cumulative probability exceeds topP
probsData := sortedProbs.Data()
cumProbsData := cumProbs.Data()
indicesData := sortedIndices.DataInt32()
// Find the cutoff index and the total probability mass to keep (renormalized below)
var cutoffIdx int
var totalProb float32
for i, cp := range cumProbsData {
totalProb += probsData[i]
if cp >= topP {
cutoffIdx = i + 1 // Include this token
break
}
}
if cutoffIdx == 0 {
cutoffIdx = len(probsData) // Use all tokens if topP is very high
}
// Sample from the truncated distribution
// Renormalize the truncated probabilities
truncatedProbs := make([]float32, cutoffIdx)
for i := 0; i < cutoffIdx; i++ {
truncatedProbs[i] = probsData[i] / totalProb
}
// Sample using random number with provided seed for reproducibility
r := mlx.RandomUniform([]int32{1}, uint64(seed))
mlx.Eval(r)
randVal := r.Data()[0]
// Find the sampled token
var cumulative float32
for i := 0; i < cutoffIdx; i++ {
cumulative += truncatedProbs[i]
if randVal < cumulative {
return indicesData[i]
}
}
// Fallback to the last token in truncated set
return indicesData[cutoffIdx-1]
}
// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}
// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
// Pre-attention norm
normed := b.InputLayerNorm.Forward(x, eps)
// Self-attention with RoPE/MRoPE and cache
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
// Post-attention norm (GLM-4 style)
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
// Residual connection
x = mlx.Add(x, attnOut)
// Post-attention layer norm
normed = b.PostAttnLayerNorm.Forward(x, eps)
// MLP
mlpOut := b.MLP.Forward(normed)
// Post-MLP norm
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
// Residual connection
x = mlx.Add(x, mlpOut)
return x
}
// Forward runs attention without a KV cache (convenience wrapper around ForwardWithCache)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}
// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
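//
// Shape sketch (illustrative, assuming hidden=4096, 32 query heads, 2 KV heads, head_dim=128):
// x [B, L, 4096] -> q [B, L, 32, 128] and k/v [B, L, 2, 128]; after transposing to
// [B, heads, L, 128] and merging with a cache of P past positions, k/v are [B, 2, P+L, 128],
// repeated 16x for GQA, and attention returns [B, 32, L, 128], flattened back to [B, L, 4096].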
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
// Q, K, V projections
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
// Add biases
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// Apply partial RoPE or MRoPE
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
if len(attn.MRoPESection) == 3 && positionIDs != nil {
// Use MRoPE with explicit position IDs
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else if len(attn.MRoPESection) == 3 {
// Use MRoPE with sequential positions (same for all dims - text mode)
seqPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
seqPositions[dim] = make([]int32, L)
for i := int32(0); i < L; i++ {
seqPositions[dim][i] = i + posOffset
}
}
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else {
// Fallback to standard RoPE
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Update cache and get full K, V for attention
if kvcache != nil {
k, v = kvcache.Update(k, v, int(L))
}
// Repeat KV for GQA
kExpanded := k
vExpanded := v
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
kExpanded = repeatKV(k, repeats)
vExpanded = repeatKV(v, repeats)
}
// Scaled dot-product attention with causal mask
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, hidden_size]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
return out
}
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}
// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply RoPE to rotary part with position offset
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply MRoPE to rotary part
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
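//
// Frequency layout (illustrative): with mropeSection=[8, 12, 12] and rotary dim D=64, the
// half=32 frequency pairs split as temporal -> global indices 0..7, height -> 8..19,
// width -> 20..31, each scaled by 1/theta^(2*idx/D). The three angle blocks are concatenated
// to [L, 32] and duplicated to [L, 64] before taking cos/sin.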
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Validate mrope_section sums to half (number of frequency pairs)
var totalPairs int32
for _, s := range mropeSection {
totalPairs += s
}
if totalPairs != half {
// Fallback to standard RoPE if section doesn't match
return applyRoPEWithOffset(x, L, 0, theta)
}
// Build angles for each position dimension (matching Python's MRoPE approach)
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
angleVals := make([]*mlx.Array, 3)
freqOffset := int32(0)
for dim := 0; dim < 3; dim++ {
numPairs := mropeSection[dim]
if numPairs == 0 {
continue
}
// Compute inverse frequencies for this section
// Each dimension uses DIFFERENT frequency ranges:
// - Temporal: frequencies 0 to section[0]-1
// - Height: frequencies section[0] to section[0]+section[1]-1
// - Width: frequencies section[0]+section[1] to sum(section)-1
freqsArr := make([]float32, numPairs)
for i := int32(0); i < numPairs; i++ {
globalIdx := freqOffset + i
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
// Position indices for this dimension
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(positionIDs[dim][i])
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, numPairs] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
freqOffset += numPairs
}
// Concatenate all sections: [L, half] = [L, 32]
allAngles := mlx.Concatenate(angleVals, 1)
// Duplicate AFTER concatenation: [L, D] = [L, 64]
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
// Compute cos/sin
allCos := mlx.Cos(allAngles)
allSin := mlx.Sin(allAngles)
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
allCos = mlx.Reshape(allCos, 1, L, 1, D)
allSin = mlx.Reshape(allSin, 1, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
return applyRoPEWithOffset(x, seqLen, 0, theta)
}
// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
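//
// Rotation sketch: for a pair (a, b) taken from positions i and i+half of the head dimension,
// the output is (a*cos - b*sin, b*cos + a*sin), i.e. the usual RoPE rotation applied in
// split-half layout rather than on interleaved pairs.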
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Compute inverse frequencies: 1 / (theta^(2i/d))
freqsArr := make([]float32, half)
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
// Position indices with offset
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(i + posOffset)
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, half] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
angles := mlx.Mul(posExpanded, freqsExpanded)
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
cosVals := mlx.Cos(anglesDup)
sinVals := mlx.Sin(anglesDup)
cosVals = mlx.Reshape(cosVals, L, 1, D)
sinVals = mlx.Reshape(sinVals, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
// repeatKV repeats key/value heads for GQA
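// Example (illustrative): with 2 KV heads repeated 16x, [1, 2, L, 128] expands to
// [1, 2, 1, L, 128], tiles to [1, 2, 16, L, 128], and reshapes to [1, 32, L, 128].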
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
// x: [B, nkvheads, L, head_dim]
x = mlx.ExpandDims(x, 2)
// x: [B, nkvheads, 1, L, head_dim]
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
// x: [B, nkvheads, repeats, L, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// Forward for GLMMLP (fused gate_up SwiGLU)
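// Sketch (assuming GateUpProj is stored as [2*intermediate, hidden]): gateUp is
// [B, L, 2*intermediate]; the first half is the gate, the second half the up projection,
// and silu(gate) * up is mapped back to hidden size by DownProj.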
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
// gate_up_proj outputs [gate, up] concatenated
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
shape := gateUp.Shape()
halfDim := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
gate = mlx.SiLU(gate)
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}

View File

@@ -19,9 +19,15 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
@@ -41,8 +47,9 @@ type Response struct {
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model *zimage.Model
model ImageModel
modelName string
modelType string // pipeline class name, e.g. "ZImagePipeline" or "GlmImagePipeline"
}
// Execute is the entry point for the image runner subprocess
@@ -72,15 +79,35 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
// Detect model type and load appropriate model
modelType, err := detectModelType(*modelName)
if err != nil {
return fmt.Errorf("failed to detect model type: %w", err)
}
var model ImageModel
switch modelType {
case "GlmImagePipeline":
slog.Info("loading GLM-Image model")
m := &glm_image.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load GLM-Image model: %w", err)
}
model = m
default:
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
slog.Info("loading Z-Image model")
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load Z-Image model: %w", err)
}
model = m
}
server := &Server{
model: model,
modelName: *modelName,
modelType: modelType,
}
// Set up HTTP handlers
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
req.Height = 1024
}
if req.Steps <= 0 {
req.Steps = 9
// Default steps depend on model type
switch s.modelType {
case "GlmImagePipeline":
req.Steps = 50 // GLM-Image default
default:
req.Steps = 9 // Z-Image turbo default
}
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image
// Generate image using interface method
ctx := r.Context()
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
if err != nil {
// Don't send error for cancellation
@@ -216,3 +233,35 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("\n"))
flusher.Flush()
}
// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", err
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return "ZImagePipeline", nil // Default to Z-Image
}
// Try both _class_name (diffusers format) and architecture (ollama format)
var index struct {
ClassName string `json:"_class_name"`
Architecture string `json:"architecture"`
}
if err := json.Unmarshal(data, &index); err != nil {
return "ZImagePipeline", nil
}
// Prefer _class_name, fall back to architecture
className := index.ClassName
if className == "" {
className = index.Architecture
}
if className == "" {
return "ZImagePipeline", nil
}
return className, nil
}
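// Example (illustrative) model_index.json values this function maps to a pipeline type:
//   {"_class_name": "GlmImagePipeline", ...}  -> "GlmImagePipeline"
//   {"architecture": "ZImagePipeline", ...}   -> "ZImagePipeline"
//   missing or unparsable config              -> "ZImagePipeline" (default)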