Compare commits
2 Commits
main...imagegen-r

| Author | SHA1 | Date |
|---|---|---|
|  | bb1a5617b6 |  |
|  | 0d3648c1be |  |
@@ -322,7 +322,6 @@ See the [API documentation](./docs/api.md) for all endpoints.

### Web & Desktop

- [Onyx](https://github.com/onyx-dot-app/onyx)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
		Prompt:         ">>> ",
		AltPrompt:      "... ",
		Placeholder:    "Send a message (/? for help)",
		AltPlaceholder: "Press Enter to send",
		AltPlaceholder: `Use """ to end multi-line input`,
	})
	if err != nil {
		return err
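The `AltPlaceholder` change above concerns the REPL's triple-quote flow. A rough sketch of what that looks like at the interactive prompt (illustrative transcript only; the prompt strings come from the fields shown in this hunk):

```shell
>>> """Summarize these notes:
... - release is on Friday
... - docs still need review
... """
```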
@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud

To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:

```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama # required but ignored
```
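With those variables set, any Anthropic-style client will talk to the local server. A minimal request sketch against the same `/v1/messages` endpoint referenced in the next hunk (the model name is the one used in the Claude Code examples; the request body is a placeholder following the Anthropic Messages shape):

```shell
curl -X POST http://localhost:11434/v1/messages \
  -H "content-type: application/json" \
  -H "x-api-key: ollama" \
  -d '{
    "model": "qwen3-coder",
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Say hello"}]
  }'
```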
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \

[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:

```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

Or set the environment variables in your shell profile:

```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";

const client = new Ollama();
const results = await client.webSearch("what is ollama?");
const results = await client.webSearch({ query: "what is ollama?" });
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";

const client = new Ollama();
const fetchResult = await client.webFetch("https://ollama.com");
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
console.log(JSON.stringify(fetchResult, null, 2));
```
@@ -111,9 +111,7 @@
        "/integrations/zed",
        "/integrations/roo-code",
        "/integrations/n8n",
        "/integrations/xcode",
        "/integrations/onyx",
        "/integrations/marimo"
        "/integrations/xcode"
      ]
    },
    {
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).

## How can I specify the context window size?

By default, Ollama uses a context window size of 4096 tokens.
By default, Ollama uses a context window size of 2048 tokens.

This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
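A sketch of that override; the 8192 value corresponds to the 8K mentioned above, and the variable is assumed to be read when the server starts:

```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```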
Removed images: nine screenshot files (174, 80, 230, 178, 186, 100, 306, 300, and 211 KiB).
@@ -25,7 +25,6 @@ Claude Code connects to Ollama using the Anthropic-compatible API.

1. Set the environment variables:

```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
export ANTHROPIC_API_KEY=ollama
```

@@ -39,7 +38,7 @@ claude --model qwen3-coder

Or run with environment variables inline:

```shell
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
```

## Connecting to ollama.com
@@ -1,73 +0,0 @@
---
title: marimo
---

## Install

Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
can also use `uv` to create a sandboxed environment for marimo by running:

```
uvx marimo edit --sandbox notebook.py
```

## Usage with Ollama

1. In marimo, go to the user settings and go to the AI tab. From here
   you can find and configure Ollama as an AI provider. For local use you
   would typically point the base url to `http://localhost:11434/v1`.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/marimo-settings.png"
    alt="Ollama settings in marimo"
    width="50%"
  />
</div>

2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/marimo-models.png"
    alt="Selecting an Ollama model"
    width="50%"
  />
</div>

3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/marimo-add-model.png"
    alt="Adding a new Ollama model"
    width="50%"
  />
</div>

4. Once configured, you can now use Ollama for AI chats in marimo.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/marimo-chat.png"
    alt="Configure code completion"
    width="50%"
  />
</div>

4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/marimo-code-completion.png"
    alt="Configure code completion"
    width="50%"
  />
</div>

## Connecting to ollama.com

1. Sign in to ollama cloud via `ollama signin`
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
3. You can now refer to this model in marimo!
@@ -1,63 +0,0 @@
---
title: Onyx
---

## Overview

[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:

- Creating custom Agents
- Web search
- Deep Research
- RAG over uploaded documents and connected apps
- Connectors to applications like Google Drive, Email, Slack, etc.
- MCP and OpenAPI Actions support
- Image generation
- User/Groups management, RBAC, SSO, etc.

Onyx can be deployed for single users or large organizations.

## Install Onyx

Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).

<Info>
  Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
</Info>

## Usage with Ollama

1. Login to your Onyx deployment (create an account first).

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/onyx-login.png"
    alt="Onyx Login Page"
    width="75%"
  />
</div>

2. In the set-up process select `Ollama` as the LLM provider.

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/onyx-ollama-llm.png"
    alt="Onyx Set Up Form"
    width="75%"
  />
</div>

3. Provide your **Ollama API URL** and select your models.

<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/onyx-ollama-form.png"
    alt="Selecting Ollama Models"
    width="75%"
  />
</div>
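The Docker note above boils down to reaching the host's Ollama from inside the Onyx container. A quick sanity check of the URL before pasting it into the form (assumes Ollama is listening on its default port):

```shell
# Run from inside the Onyx container: should return Ollama's version JSON
curl http://host.docker.internal:11434/api/version
```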
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.

## Send your first query

<div style={{ display: 'flex', justifyContent: 'center' }}>
  <img
    src="/images/onyx-query.png"
    alt="Onyx Query Example"
    width="75%"
  />
</div>
@@ -1,5 +1,5 @@
---
title: Linux
title: "Linux"
---

## Install

@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh

## Manual install

<Note>
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
</Note>

Download and extract the package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```

Start Ollama:

@@ -41,8 +40,8 @@ ollama -v

If you have an AMD GPU, also download and extract the additional ROCm package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
```

### ARM64 install

@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \

Download and extract the ARM64-specific package:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
```

### Adding Ollama as a startup service (recommended)

@@ -113,11 +112,7 @@ sudo systemctl status ollama
```

<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
</Note>

## Customizing

@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh

Or by re-downloading Ollama:

```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
```

## Installing specific versions

@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:

sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```
```
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
			t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
		}

		if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
		if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
			t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
		}
	case <-ctx.Done():
@@ -8,7 +8,6 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
|
||||
@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -266,9 +265,9 @@ type ResponsesText struct {
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description"` // nullable but required
|
||||
Strict *bool `json:"strict"` // nullable but required
|
||||
Parameters map[string]any `json:"parameters"` // nullable but required
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesRequest struct {
|
||||
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var description string
|
||||
if t.Description != nil {
|
||||
description = *t.Description
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: t.Type,
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: description,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
|
||||
|
||||
// Response types for the Responses API
|
||||
|
||||
// ResponsesTextField represents the text output configuration in the response.
|
||||
type ResponsesTextField struct {
|
||||
Format ResponsesTextFormat `json:"format"`
|
||||
}
|
||||
|
||||
// ResponsesReasoningOutput represents reasoning configuration in the response.
|
||||
type ResponsesReasoningOutput struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Summary *string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError represents an error in the response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails represents details about why a response was incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at"`
|
||||
Status string `json:"status"`
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
|
||||
Model string `json:"model"`
|
||||
PreviousResponseID *string `json:"previous_response_id"`
|
||||
Instructions *string `json:"instructions"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Error *ResponsesError `json:"error"`
|
||||
Tools []ResponsesTool `json:"tools"`
|
||||
ToolChoice any `json:"tool_choice"`
|
||||
Truncation string `json:"truncation"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
Text ResponsesTextField `json:"text"`
|
||||
TopP float64 `json:"top_p"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
|
||||
Usage *ResponsesUsage `json:"usage"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
MaxToolCalls *int `json:"max_tool_calls"`
|
||||
Store bool `json:"store"`
|
||||
Background bool `json:"background"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
SafetyIdentifier *string `json:"safety_identifier"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
|
||||
// requires additional plumbing to find the effective values since the
|
||||
// defaults can come from the model or the request
|
||||
}
|
||||
|
||||
type ResponsesOutputItem struct {
|
||||
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Annotations []any `json:"annotations"`
|
||||
Logprobs []any `json:"logprobs"`
|
||||
}
|
||||
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
|
||||
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
|
||||
func derefFloat64(p *float64, def float64) float64 {
|
||||
if p != nil {
|
||||
return *p
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response.
|
||||
// The request is used to echo back request parameters in the response.
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
|
||||
var output []ResponsesOutputItem
|
||||
|
||||
// Add reasoning item if thinking is present
|
||||
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
|
||||
output = append(output, ResponsesOutputItem{
|
||||
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
|
||||
Role: "assistant",
|
||||
Content: []ResponsesOutputContent{
|
||||
{
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Annotations: []any{},
|
||||
Logprobs: []any{},
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var instructions *string
|
||||
if request.Instructions != "" {
|
||||
instructions = &request.Instructions
|
||||
}
|
||||
|
||||
// Build truncation with default
|
||||
truncation := "disabled"
|
||||
if request.Truncation != nil {
|
||||
truncation = *request.Truncation
|
||||
}
|
||||
|
||||
tools := request.Tools
|
||||
if tools == nil {
|
||||
tools = []ResponsesTool{}
|
||||
}
|
||||
|
||||
text := ResponsesTextField{
|
||||
Format: ResponsesTextFormat{Type: "text"},
|
||||
}
|
||||
if request.Text != nil && request.Text.Format != nil {
|
||||
text.Format = *request.Text.Format
|
||||
}
|
||||
|
||||
// Build reasoning output from request
|
||||
var reasoning *ResponsesReasoningOutput
|
||||
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
|
||||
reasoning = &ResponsesReasoningOutput{}
|
||||
if request.Reasoning.Effort != "" {
|
||||
reasoning.Effort = &request.Reasoning.Effort
|
||||
}
|
||||
if request.Reasoning.Summary != "" {
|
||||
reasoning.Summary = &request.Reasoning.Summary
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesResponse{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
CompletedAt: nil, // Set by middleware when writing final response
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil, // Only populated if response incomplete
|
||||
Model: model,
|
||||
PreviousResponseID: nil, // Not supported
|
||||
Instructions: instructions,
|
||||
Output: output,
|
||||
Error: nil, // Only populated on failure
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // Default value
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: true, // Default value
|
||||
Text: text,
|
||||
TopP: derefFloat64(request.TopP, 1.0),
|
||||
PresencePenalty: 0, // Default value
|
||||
FrequencyPenalty: 0, // Default value
|
||||
TopLogprobs: 0, // Default value
|
||||
Temperature: derefFloat64(request.Temperature, 1.0),
|
||||
Reasoning: reasoning,
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
Status: "completed",
|
||||
Model: model,
|
||||
Output: output,
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: chatResponse.PromptEvalCount,
|
||||
OutputTokens: chatResponse.EvalCount,
|
||||
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
|
||||
// TODO(drifkin): wire through the actual values
|
||||
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
|
||||
// TODO(drifkin): wire through the actual values
|
||||
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
|
||||
},
|
||||
MaxOutputTokens: request.MaxOutputTokens,
|
||||
MaxToolCalls: nil, // Not supported
|
||||
Store: false, // We don't store responses
|
||||
Background: request.Background,
|
||||
ServiceTier: "default", // Default value
|
||||
Metadata: map[string]any{},
|
||||
SafetyIdentifier: nil, // Not supported
|
||||
PromptCacheKey: nil, // Not supported
|
||||
}
|
||||
}
|
||||
|
||||
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
|
||||
responseID string
|
||||
itemID string
|
||||
model string
|
||||
request ResponsesRequest
|
||||
|
||||
// State tracking (mutated across Process calls)
|
||||
firstWrite bool
|
||||
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
|
||||
}
|
||||
|
||||
// NewResponsesStreamConverter creates a new converter with the given configuration.
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
|
||||
return &ResponsesStreamConverter{
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
model: model,
|
||||
request: request,
|
||||
firstWrite: true,
|
||||
}
|
||||
}
|
||||
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
|
||||
return events
|
||||
}
|
||||
|
||||
// buildResponseObject creates a full response object with all required fields for streaming events.
|
||||
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
|
||||
var instructions any = nil
|
||||
if c.request.Instructions != "" {
|
||||
instructions = c.request.Instructions
|
||||
}
|
||||
|
||||
truncation := "disabled"
|
||||
if c.request.Truncation != nil {
|
||||
truncation = *c.request.Truncation
|
||||
}
|
||||
|
||||
var tools []any
|
||||
if c.request.Tools != nil {
|
||||
for _, t := range c.request.Tools {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": t.Description,
|
||||
"strict": t.Strict,
|
||||
"parameters": t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if tools == nil {
|
||||
tools = []any{}
|
||||
}
|
||||
|
||||
textFormat := map[string]any{"type": "text"}
|
||||
if c.request.Text != nil && c.request.Text.Format != nil {
|
||||
textFormat = map[string]any{
|
||||
"type": c.request.Text.Format.Type,
|
||||
}
|
||||
if c.request.Text.Format.Name != "" {
|
||||
textFormat["name"] = c.request.Text.Format.Name
|
||||
}
|
||||
if c.request.Text.Format.Schema != nil {
|
||||
textFormat["schema"] = c.request.Text.Format.Schema
|
||||
}
|
||||
if c.request.Text.Format.Strict != nil {
|
||||
textFormat["strict"] = *c.request.Text.Format.Strict
|
||||
}
|
||||
}
|
||||
|
||||
var reasoning any = nil
|
||||
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
|
||||
r := map[string]any{}
|
||||
if c.request.Reasoning.Effort != "" {
|
||||
r["effort"] = c.request.Reasoning.Effort
|
||||
} else {
|
||||
r["effort"] = nil
|
||||
}
|
||||
if c.request.Reasoning.Summary != "" {
|
||||
r["summary"] = c.request.Reasoning.Summary
|
||||
} else {
|
||||
r["summary"] = nil
|
||||
}
|
||||
reasoning = r
|
||||
}
|
||||
|
||||
// Build top_p and temperature with defaults
|
||||
topP := 1.0
|
||||
if c.request.TopP != nil {
|
||||
topP = *c.request.TopP
|
||||
}
|
||||
temperature := 1.0
|
||||
if c.request.Temperature != nil {
|
||||
temperature = *c.request.Temperature
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"completed_at": nil,
|
||||
"status": status,
|
||||
"incomplete_details": nil,
|
||||
"model": c.model,
|
||||
"previous_response_id": nil,
|
||||
"instructions": instructions,
|
||||
"output": output,
|
||||
"error": nil,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"truncation": truncation,
|
||||
"parallel_tool_calls": true,
|
||||
"text": map[string]any{"format": textFormat},
|
||||
"top_p": topP,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"top_logprobs": 0,
|
||||
"temperature": temperature,
|
||||
"reasoning": reasoning,
|
||||
"usage": usage,
|
||||
"max_output_tokens": c.request.MaxOutputTokens,
|
||||
"max_tool_calls": nil,
|
||||
"store": false,
|
||||
"background": c.request.Background,
|
||||
"service_tier": "default",
|
||||
"metadata": map[string]any{},
|
||||
"safety_identifier": nil,
|
||||
"prompt_cache_key": nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.created", map[string]any{
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.in_progress", map[string]any{
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
|
||||
|
||||
// Emit delta
|
||||
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"delta": thinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"delta": thinking,
|
||||
}))
|
||||
|
||||
// TODO(drifkin): consider adding
|
||||
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
|
||||
|
||||
events := []ResponsesStreamEvent{
|
||||
c.newEvent("response.reasoning_summary_text.done", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"text": c.accumulatedThinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"text": c.accumulatedThinking,
|
||||
}),
|
||||
c.newEvent("response.output_item.done", map[string]any{
|
||||
"output_index": c.outputIndex,
|
||||
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk with content
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk with thinking
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
|
||||
Content: "The answer is 42",
|
||||
},
|
||||
Done: true,
|
||||
}, ResponsesRequest{})
|
||||
})
|
||||
|
||||
// Should have 2 output items: reasoning + message
|
||||
if len(response.Output) != 2 {
|
||||
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
// Verify that response.output_item.done includes content field for messages
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.completed includes the output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// Process some content
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.created includes an empty output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hi"},
|
||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
// Verify that events include incrementing sequence numbers
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hello"},
|
||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
// Verify that function call items include status field
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -37,11 +36,10 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
buf.Remove()
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
case CharEnter, CharCtrlJ:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -179,7 +179,7 @@ _build_macapp() {
    fi

    rm -f dist/Ollama-darwin.zip
    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
    ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
    (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz

    # Notarize and Staple

@@ -187,7 +187,7 @@ _build_macapp() {
    $(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
    rm -f dist/Ollama-darwin.zip
    $(xcrun -f stapler) staple dist/Ollama.app
    ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
    ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip

    rm -f dist/Ollama.dmg
67
x/cmd/run.go
@@ -25,6 +25,14 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// MultilineState tracks the state of multiline input
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilineSystem
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
var multiline MultilineState = MultilineNone
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
multiline = MultilineNone
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
|
||||
continue
|
||||
}
|
||||
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
multiline = MultilineSystem
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
@@ -61,6 +63,7 @@ func main() {
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
@@ -117,6 +120,33 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *glmImageFlag:
|
||||
m := &glm_image.Model{}
|
||||
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
|
||||
var loadErr error
|
||||
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
|
||||
loadErr = m.LoadFromPath(*modelPath)
|
||||
} else {
|
||||
loadErr = m.Load(*modelPath)
|
||||
}
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
MaxVisualTokens: int32(*maxTokens),
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
|
||||
@@ -48,7 +48,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
|
||||
|
||||
for _, component := range components {
|
||||
componentDir := filepath.Join(modelDir, component)
|
||||
@@ -126,10 +126,13 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
|
||||
"text_encoder/generation_config.json",
|
||||
"transformer/config.json",
|
||||
"vae/config.json",
|
||||
"vision_language_encoder/config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"tokenizer/tokenizer.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"processor/tokenizer.json", // GLM-Image main tokenizer
|
||||
"processor/tokenizer_config.json", // GLM-Image tokenizer config
|
||||
}
|
||||
|
||||
for _, cfgPath := range configFiles {
|
||||
|
||||
x/imagegen/imagegen.md (new file, 19 lines)

@@ -0,0 +1,19 @@
# Image generation models (experimental)

Experimental image generation models are available for **macOS** in Ollama:

## Available models

- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)

```
ollama run x/z-image-turbo
```

> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models

More models coming soon:

1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

@@ -27,6 +27,7 @@ var modelVRAMEstimates = map[string]uint64{
	"ZImagePipeline":    21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
	"FluxPipeline":      21 * GB, // ~21GB for Flux (same architecture)
	"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
	"GlmImagePipeline":  80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}

// CheckPlatformSupport validates that image generation is supported on the current platform.
693
x/imagegen/models/glm_image/glm_image.go
Normal file
@@ -0,0 +1,693 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
|
||||
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
|
||||
// Special tokens: 0=pad, 1=eos, 2=unk
|
||||
type ByT5Tokenizer struct {
|
||||
PadTokenID int32
|
||||
EOSTokenID int32
|
||||
UNKTokenID int32
|
||||
}
|
||||
|
||||
// NewByT5Tokenizer creates a new ByT5 tokenizer
|
||||
func NewByT5Tokenizer() *ByT5Tokenizer {
|
||||
return &ByT5Tokenizer{
|
||||
PadTokenID: 0,
|
||||
EOSTokenID: 1,
|
||||
UNKTokenID: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode converts a string to token IDs
|
||||
func (t *ByT5Tokenizer) Encode(text string) []int32 {
|
||||
bytes := []byte(text)
|
||||
tokens := make([]int32, len(bytes))
|
||||
for i, b := range bytes {
|
||||
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
|
||||
// (tokens 0, 1, 2 are PAD, EOS, UNK)
|
||||
tokens[i] = int32(b) + 3
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
|
||||
bytes := make([]byte, 0, len(tokens))
|
||||
for _, tok := range tokens {
|
||||
if tok >= 3 && tok < 259 {
|
||||
bytes = append(bytes, byte(tok-3))
|
||||
}
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
|
||||
GuidanceScale float32 // Guidance scale (default: 1.5)
|
||||
Width int32 // Image width (default: 1024, must be divisible by 32)
|
||||
Height int32 // Image height (default: 1024, must be divisible by 32)
|
||||
Steps int // Diffusion denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// AR generation options
|
||||
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
|
||||
Temperature float32 // AR sampling temperature (default: 0.9)
|
||||
TopP float32 // Nucleus sampling (default: 0.75)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with stage and step progress.
|
||||
type ProgressFunc func(stage string, step, totalSteps int)
|
||||
|
||||
// Model represents a GLM-Image hybrid model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
|
||||
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
|
||||
TextEncoder *T5TextEncoder
|
||||
VisionLanguageEncoder *VisionLanguageEncoder
|
||||
Transformer *DiffusionTransformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the GLM-Image model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
// Used for T5 text encoder (glyph embeddings)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizer(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder (~830MB)
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder (~19GB, 9B params)
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer (~13GB, 7B params)
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder (~775MB)
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the model from a directory path (not ollama manifest)
|
||||
func (m *Model) LoadFromPath(modelPath string) error {
|
||||
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelPath
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizerFromPath(modelPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal generation pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.5
|
||||
}
|
||||
// Calculate MaxVisualTokens based on image dimensions
|
||||
// GLM-Image generates TWO grids of visual tokens:
|
||||
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
|
||||
// 2. Then: target (large) grid - tokenH × tokenW tokens
|
||||
// After generation, we extract only the TARGET grid tokens for diffusion.
|
||||
factor := int32(32)
|
||||
tokenH := cfg.Height / factor
|
||||
tokenW := cfg.Width / factor
|
||||
targetGridTokens := tokenH * tokenW
|
||||
|
||||
// Compute prev grid dimensions using diffusers formula:
|
||||
// ratio = token_h / token_w
|
||||
// prev_token_h = int(sqrt(ratio) * 16)
|
||||
// prev_token_w = int(sqrt(1/ratio) * 16)
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
|
||||
prevGridTokens := prevTokenH * prevTokenW
|
||||
|
||||
// Total tokens to generate = prev grid + target grid
|
||||
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
|
||||
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
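// Worked example (1024x1024 request): tokenH = tokenW = 1024/32 = 32, so the target
// grid is 32*32 = 1024 tokens; ratio = 1, so prevTokenH = prevTokenW = 16 and the
// prev grid is 16*16 = 256 tokens, giving MaxVisualTokens = 1280.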
|
||||
if cfg.Temperature <= 0 {
|
||||
cfg.Temperature = 0.9
|
||||
}
|
||||
if cfg.TopP <= 0 {
|
||||
cfg.TopP = 0.75
|
||||
}
|
||||
|
||||
// Ensure dimensions are divisible by 32
|
||||
cfg.Width = (cfg.Width / 32) * 32
|
||||
cfg.Height = (cfg.Height / 32) * 32
|
||||
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
|
||||
// Progress callback helper
|
||||
progress := func(stage string, step, total int) {
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(stage, step, total)
|
||||
}
|
||||
}
|
||||
|
||||
// === PHASE 1: T5 Text Encoding ===
|
||||
fmt.Println("[T5] Encoding glyph text...")
|
||||
progress("text_encoding", 0, 1)
|
||||
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
mlx.Keep(textEmbed)
|
||||
mlx.Eval(textEmbed)
|
||||
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
|
||||
progress("text_encoding", 1, 1)
|
||||
|
||||
// === PHASE 2: AR Visual Token Generation ===
|
||||
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
|
||||
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
|
||||
visualTokens := m.VisionLanguageEncoder.Generate(
|
||||
cfg.Prompt,
|
||||
m.GLMTokenizer,
|
||||
cfg.MaxVisualTokens,
|
||||
cfg.Temperature,
|
||||
cfg.TopP,
|
||||
cfg.Seed,
|
||||
cfg.Height,
|
||||
cfg.Width,
|
||||
func(step int) {
|
||||
if step%100 == 0 || step < 10 {
|
||||
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
|
||||
}
|
||||
progress("ar_generation", step, int(cfg.MaxVisualTokens))
|
||||
},
|
||||
)
|
||||
mlx.Keep(visualTokens)
|
||||
mlx.Eval(visualTokens)
|
||||
fmt.Printf("[AR] Done generating visual tokens\n")
|
||||
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
|
||||
|
||||
vtShape := visualTokens.Shape()
|
||||
totalGenerated := vtShape[1]
|
||||
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
|
||||
|
||||
// Extract only the TARGET grid tokens (skip the prev grid tokens)
|
||||
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
|
||||
// large_image_start_offset = prev_grid_size
|
||||
var targetGridVisualTokens *mlx.Array
|
||||
if totalGenerated >= prevGridTokens+targetGridTokens {
|
||||
// Full generation completed - extract target grid
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, prevGridTokens + targetGridTokens})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
} else if totalGenerated > prevGridTokens {
|
||||
// Partial target grid - take what we have
|
||||
actualTargetTokens := totalGenerated - prevGridTokens
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, totalGenerated})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
|
||||
actualTargetTokens, targetGridTokens)
|
||||
} else {
|
||||
// Not enough tokens - EOS came too early
|
||||
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
|
||||
totalGenerated, prevGridTokens)
|
||||
}
|
||||
|
||||
// === PHASE 3: Diffusion Decoding ===
|
||||
// Setup scheduler with dynamic shift based on image size
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
|
||||
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Initialize noise latents [B, C, H, W]
|
||||
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
|
||||
// target_grid tokens -> 2x upsample -> patch_count
|
||||
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
|
||||
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
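// Continuing the 1024x1024 example: the 32x32 target grid is repeated 2x2 into a
// 64x64 grid (4096 entries), which matches pH*pW = 64*64 patches for a 128x128
// latent with a patch size of 2 (the actual patch size comes from the transformer config).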
|
||||
|
||||
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
|
||||
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
|
||||
mlx.Keep(priorEmbed)
|
||||
mlx.Eval(priorEmbed)
|
||||
|
||||
// Prepare text conditioning (project T5 embeddings)
|
||||
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
|
||||
mlx.Keep(textCond)
|
||||
mlx.Eval(textCond)
|
||||
|
||||
// === CFG Setup ===
|
||||
// For classifier-free guidance, we need unconditional (negative) text embeddings
|
||||
// GLM-Image uses empty string "" for negative prompt
|
||||
doCFG := cfg.GuidanceScale > 1.0
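// With the default GuidanceScale of 1.5 this is true, so every denoising step runs
// two transformer passes and combines them as uncond + 1.5*(cond - uncond),
// i.e. 1.5*cond - 0.5*uncond.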
|
||||
var negativeTextCond *mlx.Array
|
||||
if doCFG {
|
||||
// Encode empty string for negative prompt
|
||||
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
|
||||
mlx.Keep(negativeTextEmbed)
|
||||
mlx.Eval(negativeTextEmbed)
|
||||
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
|
||||
mlx.Keep(negativeTextCond)
|
||||
mlx.Eval(negativeTextCond)
|
||||
negativeTextEmbed.Free()
|
||||
}
|
||||
|
||||
// Prepare conditioning inputs
|
||||
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
|
||||
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
|
||||
targetSize = mlx.ToBFloat16(targetSize)
|
||||
cropCoords = mlx.ToBFloat16(cropCoords)
|
||||
mlx.Keep(targetSize)
|
||||
mlx.Keep(cropCoords)
|
||||
mlx.Eval(targetSize, cropCoords)
|
||||
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
|
||||
progress("diffusion", 0, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
latents.Free()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Get timestep value for the transformer
|
||||
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
|
||||
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
|
||||
timestepVal := scheduler.Timesteps[i] - 1
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
|
||||
|
||||
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
// Transformer forward with MMDiT architecture
|
||||
// Conditional pass (with text + prior embeddings)
|
||||
outputCond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed,
|
||||
textCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
false, // priorTokenDrop = false for conditional
|
||||
)
|
||||
|
||||
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
|
||||
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
var noisePred *mlx.Array
|
||||
if doCFG {
|
||||
// Unconditional pass (empty text, dropped prior embeddings)
|
||||
outputUncond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
|
||||
negativeTextCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
true, // priorTokenDrop = true for unconditional
|
||||
)
|
||||
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
|
||||
diff := mlx.Sub(noisePredCond, noisePredUncond)
|
||||
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
|
||||
noisePred = mlx.Add(noisePredUncond, scaled)
|
||||
} else {
|
||||
noisePred = noisePredCond
|
||||
}
|
||||
|
||||
// Scheduler step
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
progress("diffusion", i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// Cleanup intermediate arrays
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
if negativeTextCond != nil {
|
||||
negativeTextCond.Free()
|
||||
}
|
||||
targetSize.Free()
|
||||
cropCoords.Free()
|
||||
|
||||
// === PHASE 4: VAE Decode ===
|
||||
progress("vae_decode", 0, 1)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
mlx.Eval(decoded)
|
||||
latents.Free()
|
||||
progress("vae_decode", 1, 1)
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
|
||||
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
|
||||
// scale must be 2 or 4
|
||||
//
|
||||
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
|
||||
// missing tokens are padded with 0 (visual token padding value).
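// In the pipeline above it is called with the target grid (e.g. 32x32) and scale 2,
// expanding each token into a 2x2 block to cover the 64x64 diffusion patch grid.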
|
||||
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
|
||||
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
|
||||
// Each token at (i, j) becomes scale*scale tokens in the output
|
||||
|
||||
mlx.Eval(tokens)
|
||||
tokenData := tokens.DataInt32()
|
||||
numTokens := int32(len(tokenData))
|
||||
expectedTokens := prevH * prevW
|
||||
|
||||
// Warn if we got fewer tokens than expected (early EOS)
|
||||
if numTokens < expectedTokens {
|
||||
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
|
||||
numTokens, expectedTokens)
|
||||
}
|
||||
|
||||
targetH := prevH * scale
|
||||
targetW := prevW * scale
|
||||
upsampled := make([]int32, targetH*targetW)
|
||||
|
||||
for i := int32(0); i < prevH; i++ {
|
||||
for j := int32(0); j < prevW; j++ {
|
||||
srcIdx := i*prevW + j
|
||||
|
||||
// Handle early EOS: use 0 (padding) for missing tokens
|
||||
var val int32
|
||||
if srcIdx < numTokens {
|
||||
val = tokenData[srcIdx]
|
||||
} else {
|
||||
val = 0 // Padding token
|
||||
}
|
||||
|
||||
// Place in scale*scale positions
|
||||
dstI := i * scale
|
||||
dstJ := j * scale
|
||||
for di := int32(0); di < scale; di++ {
|
||||
for dj := int32(0); dj < scale; dj++ {
|
||||
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
|
||||
}
|
||||
|
||||
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
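// For example, a [1, 16, 128, 128] latent with patchSize 2 becomes [1, 4096, 64]
// (16 channels here is illustrative; the real value is tcfg.InChannels).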
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// Transpose: -> [B, pH, pW, C, p, p]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// Flatten: -> [B, pH*pW, C*p*p]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
|
||||
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// Transpose: -> [B, C, pH, p, pW, p]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// Reshape: -> [B, C, H, W]
|
||||
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
return m*cfg.MaxShift + cfg.BaseShift
|
||||
}
|
||||
|
||||
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
|
||||
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
|
||||
// This matches diffusers' _upsample_token_ids function
|
||||
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
|
||||
shape := tokens.Shape()
|
||||
B := shape[0]
|
||||
|
||||
// Reshape to [B, 1, H, W] for interpolation
|
||||
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
|
||||
|
||||
// Convert to float for interpolation
|
||||
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
|
||||
|
||||
// 2x nearest neighbor upsample
|
||||
// [B, 1, H, W] -> [B, 1, H*2, W*2]
|
||||
upsampled := nearestUpsample2x(tokensFloat)
|
||||
|
||||
// Convert back to int and reshape to [B, H*2*W*2]
|
||||
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
|
||||
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
|
||||
}
|
||||
|
||||
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
|
||||
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Repeat each element 2x2
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
|
||||
// Tile to repeat each pixel 2x2
|
||||
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
|
||||
|
||||
// Reshape to final size
|
||||
return mlx.Reshape(x, B, C, H*2, W*2)
|
||||
}
|
||||
x/imagegen/models/glm_image/glm_tokenizer.go (new file)
@@ -0,0 +1,358 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// GLMTokenizer implements the GLM tokenizer for the AR model
|
||||
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
|
||||
// greedy longest-match tokenization from the vocab without runtime merging.
|
||||
type GLMTokenizer struct {
|
||||
Vocab map[string]int32 // token string -> token ID
|
||||
VocabReverse map[int32]string // token ID -> token string
|
||||
SpecialTokens map[string]int32 // special token strings -> IDs
|
||||
|
||||
// Special token IDs
|
||||
SopTokenID int32 // <sop> = grid_bos_token (167845)
|
||||
EopTokenID int32 // <eop> = grid_eos_token (167846)
|
||||
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
|
||||
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
|
||||
PadTokenID int32
|
||||
|
||||
// Sorted vocab keys by length (longest first) for greedy matching
|
||||
sortedTokens []string
|
||||
}
|
||||
|
||||
// tokenizerJSON represents the structure of tokenizer.json
|
||||
type tokenizerJSON struct {
|
||||
Model struct {
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
|
||||
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory in manifest
|
||||
data, err := manifest.ReadConfig("processor/tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
|
||||
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory
|
||||
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
|
||||
data, err := os.ReadFile(tokenizerPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// Encode tokenizes a string into token IDs
|
||||
// This uses greedy longest-match tokenization with GPT-2 style space handling
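// For example, "a cat" would typically become ["a", "Ġcat"] if "Ġcat" is in the vocab;
// the exact split depends on the vocabulary, so these token strings are illustrative.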
|
||||
func (t *GLMTokenizer) Encode(text string) []int32 {
|
||||
if text == "" {
|
||||
return []int32{}
|
||||
}
|
||||
|
||||
var tokens []int32
|
||||
|
||||
// First, check for and handle special tokens
|
||||
// Replace special tokens with placeholders, encode, then restore
|
||||
specialReplacements := make(map[string]int32)
|
||||
for special, id := range t.SpecialTokens {
|
||||
if strings.Contains(text, special) {
|
||||
specialReplacements[special] = id
|
||||
}
|
||||
}
|
||||
|
||||
// Process text character by character with special token handling
|
||||
i := 0
|
||||
isFirstToken := true
|
||||
|
||||
for i < len(text) {
|
||||
// Check for special tokens first
|
||||
foundSpecial := false
|
||||
for special, id := range specialReplacements {
|
||||
if strings.HasPrefix(text[i:], special) {
|
||||
tokens = append(tokens, id)
|
||||
i += len(special)
|
||||
isFirstToken = false
|
||||
foundSpecial = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular text with GPT-2 style space prefix
|
||||
// "Ġ" (U+0120) represents a space before a token
|
||||
remaining := text[i:]
|
||||
|
||||
// Try to find the longest matching token
|
||||
matched := false
|
||||
for _, token := range t.sortedTokens {
|
||||
// Skip special tokens in regular matching
|
||||
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this token matches
|
||||
tokenText := token
|
||||
|
||||
// Handle the Ġ prefix (represents space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
// This token expects a leading space
|
||||
if i > 0 || !isFirstToken {
|
||||
// Check if remaining starts with space + token content
|
||||
tokenContent := token[len("Ġ"):]
|
||||
if strings.HasPrefix(remaining, " "+tokenContent) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += 1 + len(tokenContent) // space + content
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular token without space prefix
|
||||
if strings.HasPrefix(remaining, tokenText) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += len(tokenText)
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
// No token found - skip this character (or use UNK)
|
||||
// For now, just skip unknown characters
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// EncodeForGeneration encodes a prompt with grid tokens for image generation
|
||||
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
//
|
||||
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
|
||||
// space prefix), matching the HuggingFace tokenizer behavior.
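// For a 1024x1024 target this appends "<sop>32 32<eop><sop>16 16<eop><|dit_token_16384|>"
// after the prompt tokens, with each space-prefixed number emitted as a single Ġ32/Ġ16
// token when present in the vocab.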
|
||||
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
|
||||
// Calculate grid dimensions
|
||||
factor := int32(32)
|
||||
height := (targetHeight / factor) * factor
|
||||
width := (targetWidth / factor) * factor
|
||||
tokenH := height / factor
|
||||
tokenW := width / factor
|
||||
|
||||
// Calculate previous grid dimensions
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(sqrt(ratio) * 16)
|
||||
prevTokenW := int32(sqrt(1.0/ratio) * 16)
|
||||
|
||||
// Encode the prompt text
|
||||
promptTokens := t.Encode(prompt)
|
||||
|
||||
// Build the full sequence:
|
||||
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
|
||||
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
|
||||
var tokens []int32
|
||||
tokens = append(tokens, promptTokens...)
|
||||
|
||||
// First grid: <sop> H W <eop>
|
||||
// First number has no space prefix, second number has space prefix (Ġ)
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(tokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// Second grid: <sop> prevH prevW <eop>
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// BOS token (start of image generation)
|
||||
tokens = append(tokens, t.BosTokenID)
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
|
||||
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
// First try: look up the whole number as a single token
|
||||
if id, ok := t.Vocab[s]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
// Fallback: encode digit by digit
|
||||
var tokens []int32
|
||||
for _, c := range s {
|
||||
if id, ok := t.Vocab[string(c)]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
|
||||
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
|
||||
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
|
||||
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
|
||||
spaceToken := "Ġ" + s
|
||||
if id, ok := t.Vocab[spaceToken]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
|
||||
// Fallback: bare space Ġ + number tokens
|
||||
var tokens []int32
|
||||
if spaceID, ok := t.Vocab["Ġ"]; ok {
|
||||
tokens = append(tokens, spaceID)
|
||||
}
|
||||
tokens = append(tokens, t.encodeNumber(n)...)
|
||||
return tokens
|
||||
}
|
||||
|
||||
// sqrt is a helper for float64 sqrt
|
||||
func sqrt(x float64) float64 {
|
||||
if x <= 0 {
|
||||
return 0
|
||||
}
|
||||
// Newton's method
|
||||
z := x
|
||||
for i := 0; i < 10; i++ {
|
||||
z = z - (z*z-x)/(2*z)
|
||||
}
|
||||
return z
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *GLMTokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if token, ok := t.VocabReverse[id]; ok {
|
||||
// Handle Ġ prefix (convert back to space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(token[len("Ġ"):])
|
||||
} else {
|
||||
sb.WriteString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
x/imagegen/models/glm_image/scheduler.go (new file)
@@ -0,0 +1,159 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.25
|
||||
MaxShift float32 `json:"max_shift"` // 0.75
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
|
||||
TimeShiftType string `json:"time_shift_type"` // "linear"
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns the default config for GLM-Image
|
||||
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.25,
|
||||
MaxShift: 0.75,
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 4096,
|
||||
UseDynamicShifting: true,
|
||||
TimeShiftType: "linear",
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
|
||||
Sigmas []float32 // Shifted sigmas for Euler step calculation
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{Config: cfg}
|
||||
}
|
||||
|
||||
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
|
||||
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
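// For example, with numSteps=50 the raw Timesteps run 1000, 979.6, ..., 1 and Sigmas
// holds the shifted values of t/1000 plus a terminal 0, so Sigmas has 51 entries.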
|
||||
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate shift (mu) based on image sequence length
|
||||
mu := s.calculateShift(imgSeqLen)
|
||||
|
||||
// Create timesteps: linspace from sigma_max_t to sigma_min_t
|
||||
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
|
||||
// Then apply time shift and append terminal sigma=0
|
||||
s.Timesteps = make([]float32, numSteps)
|
||||
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
|
||||
|
||||
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
|
||||
|
||||
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
// linspace from 1000 down to 1 (i.e. sigma from 1.0 down to 1/num_train_timesteps)
|
||||
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
|
||||
s.Timesteps[i] = tRaw
|
||||
|
||||
// Convert to sigma [0, 1]
|
||||
sigma := tRaw / numTrainTimesteps
|
||||
|
||||
// Apply time shift if enabled
|
||||
if s.Config.UseDynamicShifting && mu > 0 {
|
||||
sigma = s.applyShift(mu, sigma)
|
||||
}
|
||||
|
||||
s.Sigmas[i] = sigma
|
||||
}
|
||||
|
||||
// Append terminal sigma = 0 (the final clean image)
|
||||
s.Sigmas[numSteps] = 0
|
||||
}
|
||||
|
||||
// calculateShift computes dynamic shift based on image sequence length
|
||||
// Uses the sqrt-based formula from diffusers:
|
||||
// m = (image_seq_len / base_seq_len) ** 0.5
|
||||
// mu = m * max_shift + base_shift
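// For example, imgSeqLen = 4096 (a 1024x1024 image, 128x128 latent, 2x2 patches):
// m = sqrt(4096/256) = 4, so mu = 4*0.75 + 0.25 = 3.25.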
|
||||
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
|
||||
cfg := s.Config
|
||||
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
mu := m*cfg.MaxShift + cfg.BaseShift
|
||||
return mu
|
||||
}
|
||||
|
||||
// applyShift applies time shift transformation
|
||||
// mu: the computed shift value
|
||||
// t: sigma value in [0, 1]
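// With the linear shift and mu = 3.25 (from the example above), a raw sigma of 0.5
// maps to 3.25 / (3.25 + 1) ≈ 0.76, i.e. the schedule is pushed toward higher noise
// levels for larger images.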
|
||||
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
if t >= 1 {
|
||||
return 1
|
||||
}
|
||||
|
||||
// sigma=1.0 for both shift types
|
||||
sigma := float32(1.0)
|
||||
|
||||
if s.Config.TimeShiftType == "linear" {
|
||||
// Linear: mu / (mu + (1/t - 1)^sigma)
|
||||
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
|
||||
sigma := s.Sigmas[stepIdx]
|
||||
sigmaNext := s.Sigmas[stepIdx+1]
|
||||
|
||||
// Euler step: x_next = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma // Negative, since sigmas decrease from noise (~1) toward clean (0)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep (for img2img)
|
||||
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
|
||||
// Use sigmas (shifted) for the interpolation
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
oneMinusSigma := 1.0 - sigma
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
|
||||
scaledNoise := mlx.MulScalar(noise, sigma)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps
|
||||
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
x/imagegen/models/glm_image/text_encoder.go (new file)
@@ -0,0 +1,497 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// T5Config holds T5 encoder configuration
|
||||
type T5Config struct {
|
||||
DModel int32 `json:"d_model"` // 1472
|
||||
DFF int32 `json:"d_ff"` // 3584
|
||||
DKV int32 `json:"d_kv"` // 64
|
||||
NumHeads int32 `json:"num_heads"` // 6
|
||||
NumLayers int32 `json:"num_layers"` // 12
|
||||
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
|
||||
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
|
||||
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
|
||||
|
||||
// Relative position bias
|
||||
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
|
||||
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
|
||||
}
|
||||
|
||||
// T5TextEncoder is the T5 encoder for text conditioning
|
||||
type T5TextEncoder struct {
|
||||
Config *T5Config
|
||||
|
||||
// Embedding (shared for ByT5)
|
||||
SharedEmbed *nn.Embedding `weight:"shared"`
|
||||
|
||||
// Encoder layers
|
||||
Layers []*T5Block `weight:"encoder.block"`
|
||||
|
||||
// Final layer norm
|
||||
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
|
||||
|
||||
// Relative position bias (from first layer, shared across all)
|
||||
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
|
||||
}
|
||||
|
||||
// T5Block is a single T5 encoder block
|
||||
type T5Block struct {
|
||||
// Self attention
|
||||
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
|
||||
// FFN
|
||||
Layer1 *T5LayerFF `weight:"layer.1"`
|
||||
}
|
||||
|
||||
// T5LayerSelfAttention is T5's self-attention layer
|
||||
type T5LayerSelfAttention struct {
|
||||
SelfAttention *T5Attention `weight:"SelfAttention"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5Attention implements T5's relative attention
|
||||
type T5Attention struct {
|
||||
Q *mlx.Array `weight:"q.weight"` // No bias in T5
|
||||
K *mlx.Array `weight:"k.weight"`
|
||||
V *mlx.Array `weight:"v.weight"`
|
||||
O *mlx.Array `weight:"o.weight"`
|
||||
|
||||
NHeads int32
|
||||
DKV int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// T5LayerFF is T5's feedforward layer with gated-gelu
|
||||
type T5LayerFF struct {
|
||||
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5DenseGatedGelu is T5's gated-gelu FFN
|
||||
type T5DenseGatedGelu struct {
|
||||
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
|
||||
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
|
||||
Wo *mlx.Array `weight:"wo.weight"` // down projection
|
||||
}
|
||||
|
||||
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
|
||||
type T5LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Load loads the T5 text encoder from manifest
|
||||
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the T5 text encoder from a directory path
|
||||
func (m *T5TextEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *T5TextEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
m.FinalNorm.Eps = cfg.LayerNormEps
|
||||
for _, block := range m.Layers {
|
||||
attn := block.Layer0.SelfAttention
|
||||
attn.NHeads = cfg.NumHeads
|
||||
attn.DKV = cfg.DKV
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
|
||||
|
||||
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
|
||||
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Get embeddings
|
||||
h := m.SharedEmbed.Forward(tokens)
|
||||
|
||||
// Compute relative position bias once
|
||||
seqLen := tokens.Shape()[1]
|
||||
posBias := m.computeRelativePositionBias(seqLen)
|
||||
|
||||
// Forward through layers
|
||||
for _, block := range m.Layers {
|
||||
h = block.Forward(h, posBias, cfg.LayerNormEps)
|
||||
}
|
||||
|
||||
// Final norm
|
||||
h = m.FinalNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
|
||||
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
|
||||
// Glyph texts are used for text rendering guidance in the generated image
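// For example, the prompt: A storefront sign that says "OPEN 24 HOURS"
// yields glyphTexts = ["OPEN 24 HOURS"].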
|
||||
func extractGlyphTexts(prompt string) []string {
|
||||
var glyphTexts []string
|
||||
|
||||
// Extract text in single quotes: 'text'
|
||||
re1 := regexp.MustCompile(`'([^']*)'`)
|
||||
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Unicode curly double quotes: “text”
re2 := regexp.MustCompile(`“([^“”]*)”`)
|
||||
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in ASCII double quotes: "text"
|
||||
re3 := regexp.MustCompile(`"([^"]*)"`)
|
||||
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Japanese quotes: 「text」
|
||||
re4 := regexp.MustCompile(`「([^「」]*)」`)
|
||||
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
return glyphTexts
|
||||
}
|
||||
|
||||
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
|
||||
// This provides text conditioning for the diffusion transformer via the glyph projector
|
||||
//
|
||||
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
|
||||
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
|
||||
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
|
||||
// This matches diffusers' _get_glyph_embeds() behavior.
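// For example, the glyph text "OPEN" is byte-encoded as [82, 83, 72, 81, 1]
// (each byte + 3, then EOS = 1), following the byte-to-token mapping noted below.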
|
||||
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
|
||||
// Extract glyph texts from prompt (text in quotes)
|
||||
glyphTexts := extractGlyphTexts(prompt)
|
||||
|
||||
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
|
||||
if len(glyphTexts) == 0 {
|
||||
glyphTexts = []string{""}
|
||||
}
|
||||
|
||||
// Encode each glyph text and collect token sequences
|
||||
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
|
||||
var allTokenSeqs [][]int32
|
||||
|
||||
for _, glyphText := range glyphTexts {
|
||||
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
|
||||
tokens := tok.Encode(glyphText)
|
||||
|
||||
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
|
||||
tokens = append(tokens, tok.EOSTokenID)
|
||||
|
||||
allTokenSeqs = append(allTokenSeqs, tokens)
|
||||
}
|
||||
|
||||
// Process each glyph text through the encoder
|
||||
var allEmbeddings []*mlx.Array
|
||||
for _, tokens := range allTokenSeqs {
|
||||
tokenLen := len(tokens)
|
||||
if tokenLen == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create token array [1, L]
|
||||
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
|
||||
|
||||
// Forward through encoder
|
||||
output := m.Forward(tokensArr)
|
||||
mlx.Eval(output)
|
||||
|
||||
allEmbeddings = append(allEmbeddings, output)
|
||||
}
|
||||
|
||||
// Concatenate all glyph embeddings along sequence dimension
|
||||
var output *mlx.Array
|
||||
if len(allEmbeddings) == 0 {
|
||||
// Fallback: return single zero embedding
|
||||
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
|
||||
} else if len(allEmbeddings) == 1 {
|
||||
output = allEmbeddings[0]
|
||||
} else {
|
||||
output = mlx.Concatenate(allEmbeddings, 1)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// computeRelativePositionBias computes T5's relative position encoding
|
||||
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create relative position matrix
|
||||
// For each (query_pos, key_pos) pair, compute bucketed relative position
|
||||
numBuckets := cfg.RelativeAttentionNumBuckets
|
||||
maxDistance := cfg.RelativeAttentionMaxDistance
|
||||
|
||||
// Create position indices
|
||||
contextPos := make([]int32, seqLen*seqLen)
|
||||
memoryPos := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
for j := int32(0); j < seqLen; j++ {
|
||||
contextPos[i*seqLen+j] = i
|
||||
memoryPos[i*seqLen+j] = j
|
||||
}
|
||||
}
|
||||
|
||||
// Compute relative positions and bucket them
|
||||
buckets := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen*seqLen; i++ {
|
||||
relPos := memoryPos[i] - contextPos[i]
|
||||
buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
|
||||
}
|
||||
|
||||
// Create bucket indices array
|
||||
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
|
||||
|
||||
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
|
||||
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
|
||||
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
|
||||
|
||||
// Transpose to [numHeads, seqLen, seqLen]
|
||||
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
|
||||
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
|
||||
|
||||
return bias
|
||||
}
|
||||
|
||||
// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
|
||||
var bucket int32 = 0
|
||||
var n int32 = -relativePosition
|
||||
|
||||
if bidirectional {
|
||||
numBuckets /= 2
|
||||
if n < 0 {
|
||||
bucket += numBuckets
|
||||
n = -n
|
||||
}
|
||||
} else {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Half buckets are for exact positions, half are for log-spaced
|
||||
maxExact := numBuckets / 2
|
||||
if n < maxExact {
|
||||
bucket += n
|
||||
} else {
|
||||
// Log-spaced buckets
|
||||
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
|
||||
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
|
||||
if bucket > numBuckets-1 {
|
||||
bucket = numBuckets - 1
|
||||
}
|
||||
}
|
||||
|
||||
return bucket
|
||||
}
|
||||
|
||||
// Forward for T5Block
|
||||
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Self attention with residual
|
||||
h := b.Layer0.Forward(x, posBias, eps)
|
||||
|
||||
// FFN with residual
|
||||
h = b.Layer1.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward for T5LayerSelfAttention
|
||||
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// Attention
|
||||
attnOut := l.SelfAttention.Forward(normed, posBias)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, attnOut)
|
||||
}
|
||||
|
||||
// Forward for T5Attention
|
||||
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Q, K, V projections (no bias)
|
||||
// Weights are [out_features, in_features], so we use matmul with transpose
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
|
||||
|
||||
// Reshape to [B, L, nheads, d_kv]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
|
||||
|
||||
// Transpose to [B, nheads, L, d_kv]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Attention scores with relative position bias
|
||||
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
|
||||
// (no 1/sqrt(d_k) scale factor like in standard transformers)
|
||||
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
|
||||
scores = mlx.Add(scores, posBias)
|
||||
|
||||
// Softmax
|
||||
attnWeights := mlx.Softmax(scores, -1)
|
||||
|
||||
// Attend to values
|
||||
out := mlx.Matmul(attnWeights, v)
|
||||
|
||||
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, D]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
|
||||
|
||||
_ = D // Silence unused warning
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for T5LayerFF
|
||||
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// FFN
|
||||
ffOut := l.DenseReluDense.Forward(normed)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, ffOut)
|
||||
}
|
||||
|
||||
// geluNew implements the GELU activation with tanh approximation (gelu_new)
|
||||
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
|
||||
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
func geluNew(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
|
||||
coeff := float32(0.044715)
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// Forward for T5DenseGatedGelu (gated-gelu activation)
|
||||
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
|
||||
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
|
||||
gate = geluNew(gate)
|
||||
|
||||
// Up projection
|
||||
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
|
||||
|
||||
// Gated output
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
|
||||
}
|
||||
|
||||
// Forward for T5LayerNorm (RMSNorm variant)
|
||||
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
|
||||
variance := mlx.Mean(mlx.Square(x), -1, true)
|
||||
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
|
||||
return mlx.Mul(x, ln.Weight)
|
||||
}
|
||||
x/imagegen/models/glm_image/transformer.go (new file, 1255 lines; diff not shown)
x/imagegen/models/glm_image/vae.go (new file)
@@ -0,0 +1,477 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 16
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 3
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
|
||||
ShiftFactor *float32 `json:"shift_factor"` // null
|
||||
LatentsMean []float32 `json:"latents_mean"` // [16 values]
|
||||
LatentsStd []float32 `json:"latents_std"` // [16 values]
|
||||
}
|
||||
|
||||
// VAEDecoder is the VAE latent decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Decoder components
|
||||
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
|
||||
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
|
||||
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
|
||||
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
|
||||
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
|
||||
}
|
||||
|
||||
// VAEConv2d is a 2D convolution layer
|
||||
type VAEConv2d struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// GroupNorm is group normalization
|
||||
type GroupNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block of the VAE
|
||||
type VAEMidBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
}
|
||||
|
||||
// VAEUpBlock is an upsampling block
|
||||
type VAEUpBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
|
||||
}
|
||||
|
||||
// VAEResnetBlock is a residual block
|
||||
type VAEResnetBlock struct {
|
||||
Norm1 *GroupNorm `weight:"norm1"`
|
||||
Conv1 *VAEConv2d `weight:"conv1"`
|
||||
Norm2 *GroupNorm `weight:"norm2"`
|
||||
Conv2 *VAEConv2d `weight:"conv2"`
|
||||
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
|
||||
}
|
||||
|
||||
// VAEUpsampler is an upsampling layer
|
||||
type VAEUpsampler struct {
|
||||
Conv *VAEConv2d `weight:"conv"`
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from manifest
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
|
||||
// Every up_block except the last has an upsampler
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
// Every block except the last has an upsampler
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the VAE decoder from a directory path
|
||||
func (m *VAEDecoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VAEDecoder) initGroupNorms() {
|
||||
cfg := m.Config
|
||||
numGroups := cfg.NormNumGroups
|
||||
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
|
||||
|
||||
if m.ConvNormOut != nil {
|
||||
m.ConvNormOut.NumGroups = numGroups
|
||||
m.ConvNormOut.Eps = eps
|
||||
}
|
||||
|
||||
if m.MidBlock != nil {
|
||||
for _, resnet := range m.MidBlock.Resnets {
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, upBlock := range m.UpBlocks {
|
||||
if upBlock == nil {
|
||||
continue
|
||||
}
|
||||
for _, resnet := range upBlock.Resnets {
|
||||
if resnet == nil {
|
||||
continue
|
||||
}
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes latents to an image
|
||||
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Apply latent denormalization if mean/std are provided
|
||||
// This matches diffusers GLM-Image: latents = latents * std + mean
|
||||
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
|
||||
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
|
||||
latents = m.denormalizeLatents(latents)
|
||||
}
|
||||
|
||||
// Convert from NCHW to NHWC for processing
|
||||
// [B, C, H, W] -> [B, H, W, C]
|
||||
x := mlx.Transpose(latents, 0, 2, 3, 1)
|
||||
|
||||
// Initial convolution
|
||||
x = m.ConvIn.Forward(x)
|
||||
|
||||
// Mid block
|
||||
x = m.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
|
||||
for i := 0; i < len(m.UpBlocks); i++ {
|
||||
if m.UpBlocks[i] != nil {
|
||||
x = m.UpBlocks[i].Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Final normalization and convolution
|
||||
x = m.ConvNormOut.Forward(x)
|
||||
x = mlx.SiLU(x)
|
||||
x = m.ConvOut.Forward(x)
|
||||
|
||||
// Convert back to NCHW
|
||||
// [B, H, W, C] -> [B, C, H, W]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 2)
|
||||
|
||||
// Clamp to valid range and convert to [0, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.AddScalar(x, 1.0)
|
||||
x = mlx.DivScalar(x, 2.0)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// denormalizeLatents applies the latent mean/std denormalization
|
||||
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create mean and std arrays [1, C, 1, 1] for broadcasting
|
||||
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
|
||||
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
|
||||
|
||||
// Denormalize: latents * std + mean
|
||||
latents = mlx.Mul(latents, std)
|
||||
latents = mlx.Add(latents, mean)
|
||||
|
||||
return latents
|
||||
}
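// Illustrative sketch (not part of the diff): the same per-channel denormalization,
// latents*std + mean, written against a hypothetical flat NCHW buffer. The function
// name and layout are assumptions for the example only.
func denormalizeNCHW(latents []float32, c, h, w int, mean, std []float32) {
	plane := h * w
	for ci := 0; ci < c; ci++ {
		for i := 0; i < plane; i++ {
			idx := ci*plane + i
			latents[idx] = latents[idx]*std[ci] + mean[ci]
		}
	}
}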
|
||||
|
||||
// Forward for VAEConv2d
|
||||
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C_in] (NHWC)
|
||||
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
|
||||
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
|
||||
// So we need to transpose from OIHW to OHWI
|
||||
|
||||
stride := c.Stride
|
||||
if stride == 0 {
|
||||
stride = 1
|
||||
}
|
||||
padding := c.Padding
|
||||
if padding == 0 {
|
||||
// Default to same padding for 3x3 kernels
|
||||
wShape := c.Weight.Shape()
|
||||
if len(wShape) >= 3 && wShape[2] == 3 {
|
||||
padding = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
|
||||
|
||||
out := mlx.Conv2d(x, weight, stride, padding)
|
||||
if c.Bias != nil {
|
||||
// Bias: [C_out] -> [1, 1, 1, C_out]
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for GroupNorm
|
||||
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C] (NHWC)
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
numGroups := gn.NumGroups
|
||||
if numGroups == 0 {
|
||||
numGroups = 32
|
||||
}
|
||||
groupSize := C / numGroups
|
||||
|
||||
// Reshape to [B, H, W, groups, groupSize]
|
||||
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Forward for VAEMidBlock
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range mb.Resnets {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEUpBlock
|
||||
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Apply resnets
|
||||
for _, resnet := range ub.Resnets {
|
||||
if resnet != nil {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply upsamplers
|
||||
for _, upsampler := range ub.Upsamplers {
|
||||
if upsampler != nil {
|
||||
x = upsampler.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEResnetBlock
|
||||
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
|
||||
// First norm + activation + conv
|
||||
h := rb.Norm1.Forward(x)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
|
||||
// Second norm + activation + conv
|
||||
h = rb.Norm2.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
|
||||
// Shortcut for channel mismatch
|
||||
if rb.ConvShortcut != nil {
|
||||
residual = rb.ConvShortcut.Forward(residual)
|
||||
}
|
||||
|
||||
return mlx.Add(h, residual)
|
||||
}
|
||||
|
||||
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
|
||||
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C]
|
||||
// 2x nearest neighbor upsample
|
||||
x = upsample2x(x)
|
||||
|
||||
// Conv
|
||||
if us.Conv != nil {
|
||||
x = us.Conv.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2x performs 2x nearest neighbor upsampling.
|
||||
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
|
||||
hIndices := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
hIndices[i*2] = i
|
||||
hIndices[i*2+1] = i
|
||||
}
|
||||
wIndices := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
wIndices[i*2] = i
|
||||
wIndices[i*2+1] = i
|
||||
}
|
||||
|
||||
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
|
||||
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
|
||||
|
||||
// Take along width axis
|
||||
x = mlx.Reshape(x, B*H, W, C)
|
||||
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
|
||||
x = mlx.Reshape(x, B, H, W*2, C)
|
||||
|
||||
// Take along height axis - transpose to [B, W*2, H, C], take, transpose back
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
|
||||
x = mlx.Reshape(x, B*(W*2), H, C)
|
||||
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
|
||||
x = mlx.Reshape(x, B, W*2, H*2, C)
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
|
||||
|
||||
return x
|
||||
}
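// Illustrative sketch (not part of the diff): the same 2x nearest-neighbor upsampling
// on a single-channel H×W grid in plain Go, mirroring the index-duplication idea above.
// The function name is an assumption for the example only.
func upsample2xScalar(src [][]float32) [][]float32 {
	h, w := len(src), len(src[0])
	dst := make([][]float32, h*2)
	for y := range dst {
		dst[y] = make([]float32, w*2)
		for x := range dst[y] {
			dst[y][x] = src[y/2][x/2] // nearest neighbor: each source pixel fills a 2x2 block
		}
	}
	return dst
}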
|
||||
982  x/imagegen/models/glm_image/vision_language_encoder.go  Normal file
@@ -0,0 +1,982 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VisionLanguageConfig holds GLM-Image AR generator configuration
|
||||
type VisionLanguageConfig struct {
|
||||
// Text model config
|
||||
HiddenSize int32 `json:"hidden_size"` // 4096
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
|
||||
IntermediateSize int32 `json:"intermediate_size"` // 13696
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
|
||||
VocabSize int32 `json:"vocab_size"` // 168064
|
||||
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
|
||||
|
||||
// RoPE config
|
||||
RopeTheta float32 `json:"rope_theta"` // 10000
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
|
||||
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
|
||||
|
||||
// Visual token config
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
|
||||
ImageTokenID int32 `json:"image_token_id"` // 167855
|
||||
|
||||
// Computed
|
||||
HeadDim int32
|
||||
}
|
||||
|
||||
// VisionLanguageEncoder is the 9B AR generator
|
||||
type VisionLanguageEncoder struct {
|
||||
Config *VisionLanguageConfig
|
||||
|
||||
// Embedding
|
||||
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
|
||||
|
||||
// Transformer layers
|
||||
Layers []*GLMBlock `weight:"model.language_model.layers"`
|
||||
|
||||
// Final norm
|
||||
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
|
||||
|
||||
// LM Head
|
||||
LMHead *mlx.Array `weight:"lm_head.weight"`
|
||||
}
|
||||
|
||||
// GLMBlock is a single transformer block in GLM-4 style
|
||||
type GLMBlock struct {
|
||||
// Norm layers (GLM-4 applies RMSNorm both before and after the attention and MLP sublayers)
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
|
||||
|
||||
// Attention
|
||||
SelfAttn *GLMAttention `weight:"self_attn"`
|
||||
|
||||
// MLP (fused gate_up)
|
||||
MLP *GLMMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// GLMAttention implements GQA with partial rotary and MRoPE
|
||||
type GLMAttention struct {
|
||||
QProj *mlx.Array `weight:"q_proj.weight"`
|
||||
KProj *mlx.Array `weight:"k_proj.weight"`
|
||||
VProj *mlx.Array `weight:"v_proj.weight"`
|
||||
OProj *mlx.Array `weight:"o_proj.weight"`
|
||||
|
||||
// QKV have biases in GLM
|
||||
QBias *mlx.Array `weight:"q_proj.bias"`
|
||||
KBias *mlx.Array `weight:"k_proj.bias"`
|
||||
VBias *mlx.Array `weight:"v_proj.bias"`
|
||||
|
||||
// Computed
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
PartialRotary float32 // Only rotate this fraction of head_dim
|
||||
RopeTheta float32
|
||||
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
|
||||
}
|
||||
|
||||
// ARCache holds KV caches for all layers using the shared cache implementation
|
||||
type ARCache struct {
|
||||
Layers []cache.Cache
|
||||
}
|
||||
|
||||
// NewARCache creates a new cache for the given number of layers
|
||||
func NewARCache(numLayers int32) *ARCache {
|
||||
layers := make([]cache.Cache, numLayers)
|
||||
for i := range layers {
|
||||
layers[i] = cache.NewKVCache()
|
||||
}
|
||||
return &ARCache{Layers: layers}
|
||||
}
|
||||
|
||||
// Free releases all cached tensors
|
||||
func (c *ARCache) Free() {
|
||||
for _, layer := range c.Layers {
|
||||
for _, arr := range layer.State() {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GLMMLP implements fused gate_up SwiGLU MLP
|
||||
type GLMMLP struct {
|
||||
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
|
||||
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
|
||||
DownProj *mlx.Array `weight:"down_proj.weight"`
|
||||
}
|
||||
|
||||
// Load loads the vision-language encoder from manifest
|
||||
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the vision-language encoder from a directory path
|
||||
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &rawCfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VisionLanguageEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
for _, block := range m.Layers {
|
||||
block.SelfAttn.NHeads = cfg.NumAttentionHeads
|
||||
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.SelfAttn.HeadDim = cfg.HeadDim
|
||||
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
|
||||
block.SelfAttn.RopeTheta = cfg.RopeTheta
|
||||
block.SelfAttn.MRoPESection = cfg.MRoPESection
|
||||
|
||||
// Set norm eps
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
// Generate autoregressively generates visual tokens with KV caching
|
||||
func (m *VisionLanguageEncoder) Generate(
|
||||
prompt string,
|
||||
tok *GLMTokenizer,
|
||||
maxTokens int32,
|
||||
temperature float32,
|
||||
topP float32,
|
||||
seed int64,
|
||||
targetHeight, targetWidth int32,
|
||||
progressFn func(int),
|
||||
) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Encode prompt with grid tokens using GLM tokenizer
|
||||
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
|
||||
|
||||
// Calculate grid dimensions for MRoPE position IDs
|
||||
factor := int32(32)
|
||||
tokenH := targetHeight / factor
|
||||
tokenW := targetWidth / factor
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
|
||||
prevGridSize := prevTokenH * prevTokenW
|
||||
|
||||
// Create KV cache for all layers
|
||||
cache := NewARCache(cfg.NumHiddenLayers)
|
||||
defer cache.Free()
|
||||
|
||||
// ===== PREFILL PHASE =====
|
||||
// Process entire prompt at once, populate cache
|
||||
promptLen := int32(len(tokens))
|
||||
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
mlx.Eval(h)
|
||||
|
||||
// Compute position IDs for prefill (text tokens use same position for all dims)
|
||||
prefillPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
prefillPositions[dim] = make([]int32, promptLen)
|
||||
for i := int32(0); i < promptLen; i++ {
|
||||
prefillPositions[dim][i] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Forward through layers (prefill)
|
||||
for i, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
|
||||
if i > 0 {
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
// Eval h and cache arrays together so cache is materialized
|
||||
evalArgs := []*mlx.Array{h}
|
||||
for _, lc := range cache.Layers {
|
||||
evalArgs = append(evalArgs, lc.State()...)
|
||||
}
|
||||
mlx.Eval(evalArgs...)
|
||||
|
||||
// Final norm and get logits for last position
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
|
||||
h.Free()
|
||||
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
|
||||
lastH.Free()
|
||||
|
||||
// Sample first token
|
||||
var sampleCounter int64 = 0
|
||||
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
// AR generation loop with caching
|
||||
// Visual tokens are stored as VQ codebook indices [0, 16383]
|
||||
// The LM head outputs indices [0, 16511] where:
|
||||
// - [0, 16383] are VQ codes
|
||||
// - 16384 is BOS
|
||||
// - 16385 is EOS
|
||||
visualTokens := make([]int32, 0, maxTokens)
|
||||
posOffset := promptLen
|
||||
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
|
||||
|
||||
// Preallocate slice for old cache state to reuse
|
||||
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
|
||||
for i := int32(0); i < maxTokens; i++ {
|
||||
if progressFn != nil {
|
||||
progressFn(int(i))
|
||||
}
|
||||
|
||||
// Check for end token (EOS = 16385)
|
||||
if nextToken == cfg.ImageEndTokenID {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
|
||||
if nextToken == cfg.ImageStartTokenID {
|
||||
// BOS token - skip storing but continue generation
|
||||
} else if nextToken < cfg.ImageStartTokenID {
|
||||
// This is an actual VQ code [0, 16383] - store it
|
||||
visualTokens = append(visualTokens, nextToken)
|
||||
}
|
||||
// Tokens >= 16386 are other special tokens, skip them
|
||||
|
||||
// ===== DECODE PHASE =====
|
||||
// Save old cache state before forward (to free after eval)
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, lc := range cache.Layers {
|
||||
oldCacheState = append(oldCacheState, lc.State()...)
|
||||
}
|
||||
|
||||
// Only process the new token, use cached K,V
|
||||
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
// Compute MRoPE position IDs for this visual token
|
||||
// Visual tokens are arranged in two grids: prev grid then target grid
|
||||
// Position dimensions: [temporal, height, width]
|
||||
decodePositions := computeVisualTokenPositions(
|
||||
visualTokenIdx, posOffset, promptLen,
|
||||
prevTokenH, prevTokenW, prevGridSize,
|
||||
tokenH, tokenW,
|
||||
)
|
||||
|
||||
// Forward through layers (decode with cache)
|
||||
for j, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
|
||||
if j > 0 { // Don't free the embedding on first layer
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Eval h and new cache state
|
||||
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
for _, lc := range cache.Layers {
|
||||
newCacheState = append(newCacheState, lc.State()...)
|
||||
}
|
||||
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
|
||||
|
||||
// Free old cache state (now that new state is evaluated)
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
// Get logits (h is already [1, 1, hidden_size])
|
||||
h = mlx.Reshape(h, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
|
||||
h.Free()
|
||||
|
||||
// Sample next token
|
||||
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
posOffset++
|
||||
visualTokenIdx++
|
||||
|
||||
// Periodically clear cache to release intermediate memory
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
if len(visualTokens) == 0 {
|
||||
// Return at least one token to avoid empty tensor issues
|
||||
visualTokens = append(visualTokens, 0)
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
|
||||
}
|
||||
|
||||
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
|
||||
// Returns [3][1] position IDs for temporal, height, and width dimensions
|
||||
//
|
||||
// MRoPE position encoding for GLM-Image visual tokens:
|
||||
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
|
||||
// - height: decode_pos + row index within grid
|
||||
// - width: decode_pos + column index within grid
|
||||
//
|
||||
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
|
||||
// sufficient positional separation.
|
||||
func computeVisualTokenPositions(
|
||||
visualIdx int32, absPos int32, promptLen int32,
|
||||
prevH, prevW, prevSize int32,
|
||||
targetH, targetW int32,
|
||||
) [][]int32 {
|
||||
positions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
positions[dim] = make([]int32, 1)
|
||||
}
|
||||
|
||||
// First grid (prev grid) starts at decode_pos = promptLen
|
||||
prevGridDecodePos := promptLen
|
||||
|
||||
// Second grid (target grid) starts after first grid
|
||||
// next_pos = prev_decode_pos + max(prevH, prevW)
|
||||
maxPrev := prevH
|
||||
if prevW > maxPrev {
|
||||
maxPrev = prevW
|
||||
}
|
||||
targetGridDecodePos := prevGridDecodePos + maxPrev
|
||||
|
||||
// Compute position IDs based on which grid the token is in
|
||||
if visualIdx < prevSize {
|
||||
// Token is in the prev grid (prev_token_h × prev_token_w)
|
||||
row := visualIdx / prevW
|
||||
col := visualIdx % prevW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = prevGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = prevGridDecodePos + row
|
||||
positions[2][0] = prevGridDecodePos + col
|
||||
} else {
|
||||
// Token is in the target grid (token_h × token_w)
|
||||
targetIdx := visualIdx - prevSize
|
||||
row := targetIdx / targetW
|
||||
col := targetIdx % targetW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = targetGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = targetGridDecodePos + row
|
||||
positions[2][0] = targetGridDecodePos + col
|
||||
}
|
||||
|
||||
_ = targetH // Used for documentation clarity
|
||||
_ = absPos // No longer used - kept for API compatibility
|
||||
return positions
|
||||
}
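// Worked example (illustrative, with assumed sizes): for promptLen=32, a 16x16 prev
// grid (prevSize=256) and a 32x32 target grid:
//   visualIdx=0   -> temporal=32, height=32, width=32  (prev grid, row 0, col 0)
//   visualIdx=17  -> temporal=32, height=33, width=33  (prev grid, row 1, col 1)
//   visualIdx=256 -> temporal=48, height=48, width=48  (target grid starts at 32+max(16,16)=48)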
|
||||
|
||||
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
|
||||
// Note: for GLM-Image, greedy decoding is not used because it can produce repetitive outputs
|
||||
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
|
||||
// sampleCounter is incremented for each call to ensure different random values
|
||||
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
|
||||
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
|
||||
// Output index directly corresponds to vocab ID [0, 16511]
|
||||
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
|
||||
visualLogits := logits
|
||||
|
||||
// Apply temperature
|
||||
if temperature != 1.0 && temperature > 0 {
|
||||
visualLogits = mlx.DivScalar(visualLogits, temperature)
|
||||
}
|
||||
|
||||
// Apply softmax to get probabilities
|
||||
probs := mlx.Softmax(visualLogits, -1)
|
||||
mlx.Eval(probs)
|
||||
|
||||
// Get the sampled index using top-p sampling
|
||||
// This directly gives us the vocab ID in [0, 16511]
|
||||
// Special tokens: 16384 = BOS, 16385 = EOS
|
||||
// Use seed + counter for reproducible but different random values
|
||||
effectiveSeed := seed + *sampleCounter
|
||||
*sampleCounter++
|
||||
return sampleTopP(probs, topP, effectiveSeed)
|
||||
}
|
||||
|
||||
// sampleTopP implements nucleus (top-p) sampling
|
||||
// probs: [1, vocab_size] probability distribution
|
||||
// topP: cumulative probability threshold (e.g., 0.75)
|
||||
// seed: random seed for reproducible sampling
|
||||
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
|
||||
// Negate probs for descending sort (Argsort only does ascending)
|
||||
negProbs := mlx.MulScalar(probs, -1)
|
||||
sortedIndices := mlx.Argsort(negProbs, -1)
|
||||
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
|
||||
cumProbs := mlx.Cumsum(sortedProbs, -1)
|
||||
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
|
||||
|
||||
// Find cutoff index where cumulative probability exceeds topP
|
||||
probsData := sortedProbs.Data()
|
||||
cumProbsData := cumProbs.Data()
|
||||
indicesData := sortedIndices.DataInt32()
|
||||
|
||||
// Calculate cutoff and renormalize
|
||||
var cutoffIdx int
|
||||
var totalProb float32
|
||||
for i, cp := range cumProbsData {
|
||||
totalProb += probsData[i]
|
||||
if cp >= topP {
|
||||
cutoffIdx = i + 1 // Include this token
|
||||
break
|
||||
}
|
||||
}
|
||||
if cutoffIdx == 0 {
|
||||
cutoffIdx = len(probsData) // Use all tokens if topP is very high
|
||||
}
|
||||
|
||||
// Sample from the truncated distribution
|
||||
// Renormalize the truncated probabilities
|
||||
truncatedProbs := make([]float32, cutoffIdx)
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
truncatedProbs[i] = probsData[i] / totalProb
|
||||
}
|
||||
|
||||
// Sample using random number with provided seed for reproducibility
|
||||
r := mlx.RandomUniform([]int32{1}, uint64(seed))
|
||||
mlx.Eval(r)
|
||||
randVal := r.Data()[0]
|
||||
|
||||
// Find the sampled token
|
||||
var cumulative float32
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
cumulative += truncatedProbs[i]
|
||||
if randVal < cumulative {
|
||||
return indicesData[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the last token in truncated set
|
||||
return indicesData[cutoffIdx-1]
|
||||
}
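// Illustrative sketch (not part of the diff): nucleus (top-p) sampling over a plain
// probability slice, independent of mlx. Names are assumptions for the example;
// probs is assumed to already sum to 1.
//
//	package main
//
//	import (
//		"fmt"
//		"math/rand"
//		"sort"
//	)
//
//	func sampleTopPScalar(probs []float64, topP float64, rng *rand.Rand) int {
//		idx := make([]int, len(probs))
//		for i := range idx {
//			idx[i] = i
//		}
//		// Sort indices by probability, descending.
//		sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })
//
//		// Keep the smallest prefix whose cumulative mass reaches topP.
//		var kept float64
//		cutoff := len(idx)
//		for i, id := range idx {
//			kept += probs[id]
//			if kept >= topP {
//				cutoff = i + 1
//				break
//			}
//		}
//
//		// Sample from the kept prefix; scaling r by kept renormalizes implicitly.
//		r := rng.Float64() * kept
//		var acc float64
//		for _, id := range idx[:cutoff] {
//			acc += probs[id]
//			if r < acc {
//				return id
//			}
//		}
//		return idx[cutoff-1]
//	}
//
//	func main() {
//		rng := rand.New(rand.NewSource(1))
//		fmt.Println(sampleTopPScalar([]float64{0.5, 0.3, 0.15, 0.05}, 0.75, rng))
//	}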
|
||||
|
||||
// Forward for GLMBlock
|
||||
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
|
||||
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs block forward with optional KV caching and MRoPE
|
||||
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
|
||||
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
// Pre-attention norm
|
||||
normed := b.InputLayerNorm.Forward(x, eps)
|
||||
|
||||
// Self-attention with RoPE/MRoPE and cache
|
||||
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
|
||||
|
||||
// Post-attention norm (GLM-4 style)
|
||||
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
// Post-attention layer norm
|
||||
normed = b.PostAttnLayerNorm.Forward(x, eps)
|
||||
|
||||
// MLP
|
||||
mlpOut := b.MLP.Forward(normed)
|
||||
|
||||
// Post-MLP norm
|
||||
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for GLMAttention (without cache - used for prefill)
|
||||
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
|
||||
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs attention with optional KV caching and MRoPE
|
||||
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
|
||||
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
|
||||
// kvcache is updated in-place if provided
|
||||
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// Q, K, V projections
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
|
||||
|
||||
// Add biases
|
||||
if attn.QBias != nil {
|
||||
q = mlx.Add(q, attn.QBias)
|
||||
}
|
||||
if attn.KBias != nil {
|
||||
k = mlx.Add(k, attn.KBias)
|
||||
}
|
||||
if attn.VBias != nil {
|
||||
v = mlx.Add(v, attn.VBias)
|
||||
}
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// Apply partial RoPE or MRoPE
|
||||
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
|
||||
if len(attn.MRoPESection) == 3 && positionIDs != nil {
|
||||
// Use MRoPE with explicit position IDs
|
||||
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else if len(attn.MRoPESection) == 3 {
|
||||
// Use MRoPE with sequential positions (same for all dims - text mode)
|
||||
seqPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
seqPositions[dim] = make([]int32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
seqPositions[dim][i] = i + posOffset
|
||||
}
|
||||
}
|
||||
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else {
|
||||
// Fallback to standard RoPE
|
||||
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Update cache and get full K, V for attention
|
||||
if kvcache != nil {
|
||||
k, v = kvcache.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
// Repeat KV for GQA
|
||||
kExpanded := k
|
||||
vExpanded := v
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
kExpanded = repeatKV(k, repeats)
|
||||
vExpanded = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention with causal mask
|
||||
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
|
||||
|
||||
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, hidden_size]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
|
||||
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
|
||||
}
|
||||
|
||||
// applyPartialRoPEWithOffset applies RoPE with a position offset
|
||||
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply RoPE to rotary part with position offset
|
||||
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
|
||||
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
|
||||
// mrope_section: [8, 12, 12] - frequency pairs per dimension
|
||||
// For text tokens: all 3 dimensions have the same sequential position
|
||||
// For image tokens: temporal=seq_idx, height=row, width=col
|
||||
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply MRoPE to rotary part
|
||||
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyMRoPE applies multi-dimensional rotary position embedding
|
||||
// x: [B, L, H, D] where D is the rotary dimension
|
||||
// positionIDs: [3][L] - positions for temporal, height, width dimensions
|
||||
// mropeSection: [8, 12, 12] - frequency pairs per dimension
|
||||
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Validate mrope_section sums to half (number of frequency pairs)
|
||||
var totalPairs int32
|
||||
for _, s := range mropeSection {
|
||||
totalPairs += s
|
||||
}
|
||||
if totalPairs != half {
|
||||
// Fallback to standard RoPE if section doesn't match
|
||||
return applyRoPEWithOffset(x, L, 0, theta)
|
||||
}
|
||||
|
||||
// Build angles for each position dimension (matching Python's MRoPE approach)
|
||||
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
|
||||
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
|
||||
angleVals := make([]*mlx.Array, 3)
|
||||
|
||||
freqOffset := int32(0)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
numPairs := mropeSection[dim]
|
||||
if numPairs == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compute inverse frequencies for this section
|
||||
// Each dimension uses DIFFERENT frequency ranges:
|
||||
// - Temporal: frequencies 0 to section[0]-1
|
||||
// - Height: frequencies section[0] to section[0]+section[1]-1
|
||||
// - Width: frequencies section[0]+section[1] to sum(section)-1
|
||||
freqsArr := make([]float32, numPairs)
|
||||
for i := int32(0); i < numPairs; i++ {
|
||||
globalIdx := freqOffset + i
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
|
||||
|
||||
// Position indices for this dimension
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(positionIDs[dim][i])
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, numPairs] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
|
||||
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
freqOffset += numPairs
|
||||
}
|
||||
|
||||
// Concatenate all sections: [L, half] = [L, 32]
|
||||
allAngles := mlx.Concatenate(angleVals, 1)
|
||||
|
||||
// Duplicate AFTER concatenation: [L, D] = [L, 64]
|
||||
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
|
||||
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
|
||||
|
||||
// Compute cos/sin
|
||||
allCos := mlx.Cos(allAngles)
|
||||
allSin := mlx.Sin(allAngles)
|
||||
|
||||
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
|
||||
allCos = mlx.Reshape(allCos, 1, L, 1, D)
|
||||
allSin = mlx.Reshape(allSin, 1, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1)
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
|
||||
}
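// Worked example (illustrative): with mrope_section = [8, 12, 12] and a 64-dim rotary
// slice, half = 32 frequency pairs are assigned as
//   pairs  0..7   -> rotated by the temporal position
//   pairs  8..19  -> rotated by the height position
//   pairs 20..31  -> rotated by the width position
// and after the final duplication each position's angle vector has the layout
//   [t0..t7, h0..h11, w0..w11, t0..t7, h0..h11, w0..w11]  (length 64).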
|
||||
|
||||
// applyRoPE applies rotary position embedding
|
||||
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
return applyRoPEWithOffset(x, seqLen, 0, theta)
|
||||
}
|
||||
|
||||
// applyRoPEWithOffset applies rotary position embedding with position offset
|
||||
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
|
||||
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Compute inverse frequencies: 1 / (theta^(2i/d))
|
||||
freqsArr := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
// Position indices with offset
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(i + posOffset)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, half] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
angles := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
|
||||
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
|
||||
|
||||
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
|
||||
cosVals := mlx.Cos(anglesDup)
|
||||
sinVals := mlx.Sin(anglesDup)
|
||||
cosVals = mlx.Reshape(cosVals, L, 1, D)
|
||||
sinVals = mlx.Reshape(sinVals, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
|
||||
}
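// Illustrative sketch (not part of the diff): the split-half rotation above, written
// out for a single (real, imag) pair at angle = pos * invFreq. Uses the math package
// already imported in this file; the function name is an assumption for the example.
func rotatePair(re, im, angle float64) (float64, float64) {
	c, s := math.Cos(angle), math.Sin(angle)
	// Matches out = x*cos + rotate_half(x)*sin with rotate_half(x) = [-imag, real].
	return re*c - im*s, im*c + re*s
}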
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// x: [B, nkvheads, L, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// x: [B, nkvheads, 1, L, head_dim]
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
// x: [B, nkvheads, repeats, L, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Forward for GLMMLP (fused gate_up SwiGLU)
|
||||
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
// gate_up_proj outputs [gate, up] concatenated
|
||||
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
|
||||
|
||||
shape := gateUp.Shape()
|
||||
halfDim := shape[len(shape)-1] / 2
|
||||
|
||||
// Split into gate and up
|
||||
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
|
||||
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
|
||||
|
||||
// SwiGLU: silu(gate) * up
|
||||
gate = mlx.SiLU(gate)
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
|
||||
}
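// Illustrative sketch (not part of the diff): the same SwiGLU gating on scalars,
// silu(gate) * up, to make the fused gate_up split above concrete. Uses the math
// package already imported in this file; names are assumptions for the example.
func siluScalar(x float64) float64 { return x / (1 + math.Exp(-x)) }

func swiGLUScalar(gate, up float64) float64 { return siluScalar(gate) * up }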
|
||||
@@ -19,9 +19,15 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -41,8 +47,9 @@ type Response struct {
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model *zimage.Model
|
||||
model ImageModel
|
||||
modelName string
|
||||
modelType string // "zimage" or "glm_image"
|
||||
}
|
||||
|
||||
// Execute is the entry point for the image runner subprocess
|
||||
@@ -72,15 +79,35 @@ func Execute(args []string) error {
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Load model
|
||||
model := &zimage.Model{}
|
||||
if err := model.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
// Detect model type and load appropriate model
|
||||
modelType, err := detectModelType(*modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to detect model type: %w", err)
|
||||
}
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "GlmImagePipeline":
|
||||
slog.Info("loading GLM-Image model")
|
||||
m := &glm_image.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load GLM-Image model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
|
||||
slog.Info("loading Z-Image model")
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load Z-Image model: %w", err)
|
||||
}
|
||||
model = m
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
modelName: *modelName,
|
||||
modelType: modelType,
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
req.Steps = 9
|
||||
// Default steps depend on model type
|
||||
switch s.modelType {
|
||||
case "GlmImagePipeline":
|
||||
req.Steps = 50 // GLM-Image default
|
||||
default:
|
||||
req.Steps = 9 // Z-Image turbo default
|
||||
}
|
||||
}
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image
|
||||
// Generate image using interface method
|
||||
ctx := r.Context()
|
||||
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: req.Seed,
|
||||
Progress: func(step, total int) {
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
},
|
||||
})
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
@@ -216,3 +233,35 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// detectModelType reads the model manifest and returns the pipeline class name
|
||||
func detectModelType(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return "ZImagePipeline", nil // Default to Z-Image
|
||||
}
|
||||
|
||||
// Try both _class_name (diffusers format) and architecture (ollama format)
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return "ZImagePipeline", nil
|
||||
}
|
||||
|
||||
// Prefer _class_name, fall back to architecture
|
||||
className := index.ClassName
|
||||
if className == "" {
|
||||
className = index.Architecture
|
||||
}
|
||||
if className == "" {
|
||||
return "ZImagePipeline", nil
|
||||
}
|
||||
return className, nil
|
||||
}
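// Example (illustrative): a minimal model_index.json that detectModelType would map to
// the GLM-Image pipeline; any other or missing class name falls back to Z-Image.
//
//	{
//	  "_class_name": "GlmImagePipeline"
//	}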