Compare commits
2 Commits
mxyng/simp
...
imagegen-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb1a5617b6 | ||
|
|
0d3648c1be |
18
Dockerfile
@@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
# install epel-release for ccache
|
||||
RUN yum install -y yum-utils epel-release \
|
||||
&& dnf install -y clang ccache git \
|
||||
&& dnf install -y clang ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
|
||||
ENV CC=clang CXX=clang++
|
||||
|
||||
@@ -149,7 +149,6 @@ COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
@@ -157,6 +156,14 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \
|
||||
&& cmake --install build --component MLX --strip --parallel ${PARALLEL}
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN mkdir -p dist/bin
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx .
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
@@ -165,14 +172,12 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-
|
||||
ENV PATH=/usr/local/go/bin:$PATH
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Clone mlx-c headers for CGO (version from MLX_VERSION file)
|
||||
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
@@ -180,6 +185,7 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
|
||||
COPY --from=vulkan dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/
|
||||
COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
v0.4.1
|
||||
@@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation:
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx .
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
@@ -322,7 +322,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Onyx](https://github.com/onyx-dot-app/onyx)
|
||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
|
||||
28
api/types.go
@@ -127,20 +127,6 @@ type GenerateRequest struct {
|
||||
// each with an associated log probability. Only applies when Logprobs is true.
|
||||
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// Experimental: Image generation fields (may change or be removed)
|
||||
|
||||
// Width is the width of the generated image in pixels.
|
||||
// Only used for image generation models.
|
||||
Width int32 `json:"width,omitempty"`
|
||||
|
||||
// Height is the height of the generated image in pixels.
|
||||
// Only used for image generation models.
|
||||
Height int32 `json:"height,omitempty"`
|
||||
|
||||
// Steps is the number of diffusion steps for image generation.
|
||||
// Only used for image generation models.
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -874,20 +860,6 @@ type GenerateResponse struct {
|
||||
// Logprobs contains log probability information for the generated tokens,
|
||||
// if requested via the Logprobs parameter.
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Experimental: Image generation fields (may change or be removed)
|
||||
|
||||
// Image contains a base64-encoded generated image.
|
||||
// Only present for image generation models.
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Completed is the number of completed steps in image generation.
|
||||
// Only present for image generation models during streaming.
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
|
||||
// Total is the total number of steps for image generation.
|
||||
// Only present for image generation models during streaming.
|
||||
Total int64 `json:"total,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
|
||||
@@ -14,7 +14,6 @@ extern NSString *SystemWidePath;
|
||||
@interface AppDelegate () <NSWindowDelegate, WKNavigationDelegate, WKUIDelegate>
|
||||
@property(strong, nonatomic) NSStatusItem *statusItem;
|
||||
@property(assign, nonatomic) BOOL updateAvailable;
|
||||
@property(assign, nonatomic) BOOL systemShutdownInProgress;
|
||||
@end
|
||||
|
||||
@implementation AppDelegate
|
||||
@@ -41,13 +40,6 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)applicationDidFinishLaunching:(NSNotification *)aNotification {
|
||||
// Register for system shutdown/restart notification so we can allow termination
|
||||
[[[NSWorkspace sharedWorkspace] notificationCenter]
|
||||
addObserver:self
|
||||
selector:@selector(systemWillPowerOff:)
|
||||
name:NSWorkspaceWillPowerOffNotification
|
||||
object:nil];
|
||||
|
||||
// if we're in development mode, set the app icon
|
||||
NSString *bundlePath = [[NSBundle mainBundle] bundlePath];
|
||||
if (![bundlePath hasSuffix:@".app"]) {
|
||||
@@ -286,18 +278,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
[NSApp activateIgnoringOtherApps:YES];
|
||||
}
|
||||
|
||||
- (void)systemWillPowerOff:(NSNotification *)notification {
|
||||
// Set flag so applicationShouldTerminate: knows to allow termination.
|
||||
// The system will call applicationShouldTerminate: after posting this notification.
|
||||
self.systemShutdownInProgress = YES;
|
||||
}
|
||||
|
||||
- (NSApplicationTerminateReply)applicationShouldTerminate:(NSApplication *)sender {
|
||||
// Allow termination if the system is shutting down or restarting
|
||||
if (self.systemShutdownInProgress) {
|
||||
return NSTerminateNow;
|
||||
}
|
||||
// Otherwise just hide the app (for Cmd+Q, close button, etc.)
|
||||
[NSApp hide:nil];
|
||||
[NSApp setActivationPolicy:NSApplicationActivationPolicyAccessory];
|
||||
return NSTerminateCancel;
|
||||
|
||||
103
cmd/cmd.go
@@ -46,9 +46,8 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -94,87 +93,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Validate model name early to fail fast
|
||||
modelName := args[0]
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) || filename == "" {
|
||||
// No Modelfile specified or found - use default
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
}
|
||||
|
||||
// Parse the Modelfile
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if create.IsTensorModelDir(".") {
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: ".",
|
||||
Quantize: quantize,
|
||||
}, p)
|
||||
return imagegenclient.CreateModel(args[0], ".", quantize, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
@@ -207,7 +134,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = modelName
|
||||
req.Model = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
@@ -600,7 +527,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
// Check if this is an image generation model
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImage) {
|
||||
if slices.Contains(info.Capabilities, model.CapabilityImageGeneration) {
|
||||
if opts.Prompt == "" && !interactive {
|
||||
return errors.New("image generation models require a prompt. Usage: ollama run " + name + " \"your prompt here\"")
|
||||
}
|
||||
@@ -1815,22 +1742,15 @@ func NewCLI() *cobra.Command {
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Skip server check for experimental mode (writes directly to disk)
|
||||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
@@ -1958,7 +1878,7 @@ func NewCLI() *cobra.Command {
|
||||
Use: "runner",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return runner.Execute(os.Args[2:])
|
||||
return runner.Execute(os.Args[1:])
|
||||
},
|
||||
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
|
||||
}
|
||||
@@ -1985,7 +1905,6 @@ func NewCLI() *cobra.Command {
|
||||
} {
|
||||
switch cmd {
|
||||
case runCmd:
|
||||
imagegen.AppendFlagsDocs(cmd)
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
|
||||
case serveCmd:
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||
|
||||
@@ -1555,7 +1555,7 @@ func TestShowInfoImageGen(t *testing.T) {
|
||||
ParameterSize: "10.3B",
|
||||
QuantizationLevel: "FP8",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityImage},
|
||||
Capabilities: []model.Capability{model.CapabilityImageGeneration},
|
||||
Requires: "0.14.0",
|
||||
}, false, &b)
|
||||
if err != nil {
|
||||
|
||||
@@ -116,7 +116,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
15
cmd/runner/main.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/runner"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := runner.Execute(os.Args[1:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
62
docs/api.md
@@ -16,7 +16,6 @@
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Version](#version)
|
||||
- [Experimental: Image Generation](#image-generation-experimental)
|
||||
|
||||
## Conventions
|
||||
|
||||
@@ -59,15 +58,6 @@ Advanced parameters (optional):
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
|
||||
Experimental image generation parameters (for image generation models only):
|
||||
|
||||
> [!WARNING]
|
||||
> These parameters are experimental and may change in future versions.
|
||||
|
||||
- `width`: width of the generated image in pixels
|
||||
- `height`: height of the generated image in pixels
|
||||
- `steps`: number of diffusion steps
|
||||
|
||||
#### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below.
|
||||
@@ -1877,55 +1867,3 @@ curl http://localhost:11434/api/version
|
||||
"version": "0.5.1"
|
||||
}
|
||||
```
|
||||
|
||||
## Experimental Features
|
||||
|
||||
### Image Generation (Experimental)
|
||||
|
||||
> [!WARNING]
|
||||
> Image generation is experimental and may change in future versions.
|
||||
|
||||
Image generation is now supported through the standard `/api/generate` endpoint when using image generation models. The API automatically detects when an image generation model is being used.
|
||||
|
||||
See the [Generate a completion](#generate-a-completion) section for the full API documentation. The experimental image generation parameters (`width`, `height`, `steps`) are documented there.
|
||||
|
||||
#### Example
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "x/z-image-turbo",
|
||||
"prompt": "a sunset over mountains",
|
||||
"width": 1024,
|
||||
"height": 768
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response (streaming)
|
||||
|
||||
Progress updates during generation:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "x/z-image-turbo",
|
||||
"created_at": "2024-01-15T10:30:00.000000Z",
|
||||
"completed": 5,
|
||||
"total": 20,
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
##### Final Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "x/z-image-turbo",
|
||||
"created_at": "2024-01-15T10:30:15.000000Z",
|
||||
"image": "iVBORw0KGgoAAAANSUhEUg...",
|
||||
"done": true,
|
||||
"done_reason": "stop",
|
||||
"total_duration": 15000000000,
|
||||
"load_duration": 2000000000
|
||||
}
|
||||
```
|
||||
|
||||
@@ -21,7 +21,6 @@ ollama pull glm-4.7:cloud
|
||||
To use Ollama with tools that expect the Anthropic API (like Claude Code), set these environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
@@ -248,13 +247,12 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
@@ -275,73 +275,6 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
- [x] `dimensions`
|
||||
- [ ] `user`
|
||||
|
||||
### `/v1/images/generations` (experimental)
|
||||
|
||||
> Note: This endpoint is experimental and may change or be removed in future versions.
|
||||
|
||||
Generate images using image generation models.
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python images.py
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url='http://localhost:11434/v1/',
|
||||
api_key='ollama', # required but ignored
|
||||
)
|
||||
|
||||
response = client.images.generate(
|
||||
model='x/z-image-turbo',
|
||||
prompt='A cute robot learning to paint',
|
||||
size='1024x1024',
|
||||
response_format='b64_json',
|
||||
)
|
||||
print(response.data[0].b64_json[:50] + '...')
|
||||
```
|
||||
|
||||
```javascript images.js
|
||||
import OpenAI from "openai";
|
||||
|
||||
const openai = new OpenAI({
|
||||
baseURL: "http://localhost:11434/v1/",
|
||||
apiKey: "ollama", // required but ignored
|
||||
});
|
||||
|
||||
const response = await openai.images.generate({
|
||||
model: "x/z-image-turbo",
|
||||
prompt: "A cute robot learning to paint",
|
||||
size: "1024x1024",
|
||||
response_format: "b64_json",
|
||||
});
|
||||
|
||||
console.log(response.data[0].b64_json.slice(0, 50) + "...");
|
||||
```
|
||||
|
||||
```shell images.sh
|
||||
curl -X POST http://localhost:11434/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "x/z-image-turbo",
|
||||
"prompt": "A cute robot learning to paint",
|
||||
"size": "1024x1024",
|
||||
"response_format": "b64_json"
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
#### Supported request fields
|
||||
|
||||
- [x] `model`
|
||||
- [x] `prompt`
|
||||
- [x] `size` (e.g. "1024x1024")
|
||||
- [x] `response_format` (only `b64_json` supported)
|
||||
- [ ] `n`
|
||||
- [ ] `quality`
|
||||
- [ ] `style`
|
||||
- [ ] `user`
|
||||
|
||||
### `/v1/responses`
|
||||
|
||||
> Note: Added in Ollama v0.13.3
|
||||
|
||||
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const results = await client.webSearch("what is ollama?");
|
||||
const results = await client.webSearch({ query: "what is ollama?" });
|
||||
console.log(JSON.stringify(results, null, 2));
|
||||
```
|
||||
|
||||
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const fetchResult = await client.webFetch("https://ollama.com");
|
||||
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
|
||||
console.log(JSON.stringify(fetchResult, null, 2));
|
||||
```
|
||||
|
||||
|
||||
@@ -111,9 +111,7 @@
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
"/integrations/xcode"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 174 KiB |
|
Before Width: | Height: | Size: 80 KiB |
|
Before Width: | Height: | Size: 230 KiB |
|
Before Width: | Height: | Size: 178 KiB |
|
Before Width: | Height: | Size: 186 KiB |
|
Before Width: | Height: | Size: 100 KiB |
|
Before Width: | Height: | Size: 306 KiB |
|
Before Width: | Height: | Size: 300 KiB |
|
Before Width: | Height: | Size: 211 KiB |
@@ -2,12 +2,6 @@
|
||||
title: Claude Code
|
||||
---
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
|
||||
|
||||

|
||||
|
||||
## Install
|
||||
|
||||
Install [Claude Code](https://code.claude.com/docs/en/overview):
|
||||
@@ -31,24 +25,22 @@ Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
```
|
||||
|
||||
2. Run Claude Code with an Ollama model:
|
||||
|
||||
```shell
|
||||
claude --model gpt-oss:20b
|
||||
claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||
ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
@@ -75,4 +67,3 @@ claude --model glm-4.7:cloud
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||
@@ -1,73 +0,0 @@
|
||||
---
|
||||
title: marimo
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
Install [marimo](https://marimo.io). You can use `pip` or `uv` for this. You
|
||||
can also use `uv` to create a sandboxed environment for marimo by running:
|
||||
|
||||
```
|
||||
uvx marimo edit --sandbox notebook.py
|
||||
```
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. In marimo, go to the user settings and go to the AI tab. From here
|
||||
you can find and configure Ollama as an AI provider. For local use you
|
||||
would typically point the base url to `http://localhost:11434/v1`.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-settings.png"
|
||||
alt="Ollama settings in marimo"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
2. Once the AI provider is set up, you can turn on/off specific AI models you'd like to access.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-models.png"
|
||||
alt="Selecting an Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
3. You can also add a model to the list of available models by scrolling to the bottom and using the UI there.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-add-model.png"
|
||||
alt="Adding a new Ollama model"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Once configured, you can now use Ollama for AI chats in marimo.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-chat.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
4. Alternatively, you can now use Ollama for **inline code completion** in marimo. This can be configured in the "AI Features" tab.
|
||||
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/marimo-code-completion.png"
|
||||
alt="Configure code completion"
|
||||
width="50%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Sign in to ollama cloud via `ollama signin`
|
||||
2. In the ollama model settings add a model that ollama hosts, like `gpt-oss:120b`.
|
||||
3. You can now refer to this model in marimo!
|
||||
@@ -1,63 +0,0 @@
|
||||
---
|
||||
title: Onyx
|
||||
---
|
||||
|
||||
## Overview
|
||||
[Onyx](http://onyx.app/) is a self-hostable Chat UI that integrates with all Ollama models. Features include:
|
||||
- Creating custom Agents
|
||||
- Web search
|
||||
- Deep Research
|
||||
- RAG over uploaded documents and connected apps
|
||||
- Connectors to applications like Google Drive, Email, Slack, etc.
|
||||
- MCP and OpenAPI Actions support
|
||||
- Image generation
|
||||
- User/Groups management, RBAC, SSO, etc.
|
||||
|
||||
Onyx can be deployed for single users or large organizations.
|
||||
|
||||
## Install Onyx
|
||||
|
||||
Deploy Onyx with the [quickstart guide](https://docs.onyx.app/deployment/getting_started/quickstart).
|
||||
|
||||
<Info>
|
||||
Resourcing/scaling docs [here](https://docs.onyx.app/deployment/getting_started/resourcing).
|
||||
</Info>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
1. Login to your Onyx deployment (create an account first).
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-login.png"
|
||||
alt="Onyx Login Page"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
2. In the set-up process select `Ollama` as the LLM provider.
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-llm.png"
|
||||
alt="Onyx Set Up Form"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
3. Provide your **Ollama API URL** and select your models.
|
||||
<Note>If you're running Onyx in Docker, to access your computer's local network use `http://host.docker.internal` instead of `http://127.0.0.1`.</Note>
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-ollama-form.png"
|
||||
alt="Selecting Ollama Models"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can also easily connect up Onyx Cloud with the `Ollama Cloud` tab of the setup.
|
||||
|
||||
## Send your first query
|
||||
<div style={{ display: 'flex', justifyContent: 'center' }}>
|
||||
<img
|
||||
src="/images/onyx-query.png"
|
||||
alt="Onyx Query Example"
|
||||
width="75%"
|
||||
/>
|
||||
</div>
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Linux
|
||||
title: "Linux"
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,15 +13,14 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -41,8 +40,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -50,8 +49,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -113,11 +112,7 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -146,8 +141,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -196,4 +191,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
@@ -131,7 +131,7 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -1464,12 +1464,6 @@ type CompletionRequest struct {
|
||||
|
||||
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
|
||||
TopLogprobs int
|
||||
|
||||
// Image generation fields
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1518,15 +1512,6 @@ type CompletionResponse struct {
|
||||
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
|
||||
// Image contains base64-encoded image data for image generation
|
||||
Image string `json:"image,omitempty"`
|
||||
|
||||
// Step is the current step in image generation
|
||||
Step int `json:"step,omitempty"`
|
||||
|
||||
// TotalSteps is the total number of steps for image generation
|
||||
TotalSteps int `json:"total_steps,omitempty"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -442,7 +441,6 @@ type ResponsesWriter struct {
|
||||
stream bool
|
||||
responseID string
|
||||
itemID string
|
||||
request openai.ResponsesRequest
|
||||
}
|
||||
|
||||
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
|
||||
@@ -480,9 +478,7 @@ func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// Non-streaming response
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse, w.request)
|
||||
completedAt := time.Now().Unix()
|
||||
response.CompletedAt = &completedAt
|
||||
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
|
||||
}
|
||||
|
||||
@@ -527,12 +523,11 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
|
||||
w := &ResponsesWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model, req),
|
||||
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
|
||||
model: req.Model,
|
||||
stream: streamRequested,
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
request: req,
|
||||
}
|
||||
|
||||
// Set headers based on streaming mode
|
||||
@@ -546,66 +541,3 @@ func ResponsesMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type ImageWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
func (w *ImageWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
if err := json.Unmarshal(data, &generateResponse); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Only write response when done with image
|
||||
if generateResponse.Done && generateResponse.Image != "" {
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
return len(data), json.NewEncoder(w.ResponseWriter).Encode(openai.ToImageGenerationResponse(generateResponse))
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ImageWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageGenerationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(openai.FromImageGenerationRequest(req)); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -961,154 +961,3 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerationsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image generation basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "a beautiful sunset"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "a beautiful sunset",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "a beautiful sunset",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "a beautiful sunset",
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation missing prompt",
|
||||
body: `{
|
||||
"model": "test-model"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image generation missing model",
|
||||
body: `{
|
||||
"prompt": "a beautiful sunset"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageWriterResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Test that ImageWriter transforms GenerateResponse to OpenAI format
|
||||
endpoint := func(c *gin.Context) {
|
||||
resp := api.GenerateResponse{
|
||||
Model: "test-model",
|
||||
CreatedAt: time.Unix(1234567890, 0).UTC(),
|
||||
Done: true,
|
||||
Image: "dGVzdC1pbWFnZS1kYXRh", // base64 of "test-image-data"
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(ImageGenerationsMiddleware())
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
body := `{"model": "test-model", "prompt": "test"}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var imageResp openai.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &imageResp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if imageResp.Created != 1234567890 {
|
||||
t.Errorf("expected created 1234567890, got %d", imageResp.Created)
|
||||
}
|
||||
|
||||
if len(imageResp.Data) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(imageResp.Data))
|
||||
}
|
||||
|
||||
if imageResp.Data[0].B64JSON != "dGVzdC1pbWFnZS1kYXRh" {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
@@ -13,114 +14,243 @@ const (
|
||||
Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota
|
||||
Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
Nemotron3NanoCollectingContent
|
||||
Nemotron3NanoCollectingToolCalls
|
||||
)
|
||||
|
||||
const (
|
||||
nemotronThinkClose = "</think>"
|
||||
nemotronToolCallOpen = "<tool_call>"
|
||||
nemotronThinkClose = "</think>"
|
||||
nemotronToolCallOpen = "<tool_call>"
|
||||
nemotronToolCallClose = "</tool_call>"
|
||||
)
|
||||
|
||||
type Nemotron3NanoParser struct {
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
toolParser *Qwen3CoderParser
|
||||
state Nemotron3NanoParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) HasToolSupport() bool { return true }
|
||||
func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true }
|
||||
|
||||
func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.toolParser = &Qwen3CoderParser{}
|
||||
p.toolParser.Init(tools, nil, nil)
|
||||
p.tools = tools
|
||||
|
||||
// thinking is enabled if user requests it
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
if !thinkingEnabled || (prefill && lastMessage.Content != "") {
|
||||
if !thinkingEnabled {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingThinking
|
||||
return tools
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
return tools
|
||||
}
|
||||
|
||||
p.state = Nemotron3NanoCollectingThinking
|
||||
return tools
|
||||
}
|
||||
|
||||
type nemotronEvent interface {
|
||||
isNemotronEvent()
|
||||
}
|
||||
|
||||
type nemotronEventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type nemotronEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type nemotronEventToolCall struct {
|
||||
toolCall api.ToolCall
|
||||
}
|
||||
|
||||
func (nemotronEventThinkingContent) isNemotronEvent() {}
|
||||
func (nemotronEventContent) isNemotronEvent() {}
|
||||
func (nemotronEventToolCall) isNemotronEvent() {}
|
||||
|
||||
func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
if p.state == Nemotron3NanoCollectingContent {
|
||||
return p.toolParser.Add(s, done)
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case nemotronEventToolCall:
|
||||
toolCalls = append(toolCalls, event.toolCall)
|
||||
case nemotronEventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case nemotronEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
if p.state == Nemotron3NanoSkipWhitespaceAfterThinking {
|
||||
s = strings.TrimLeftFunc(s, unicode.IsSpace)
|
||||
if s == "" {
|
||||
return "", "", nil, nil
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent {
|
||||
var all []nemotronEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []nemotronEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
|
||||
func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
|
||||
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:]
|
||||
}
|
||||
wsLen := trailingWhitespaceLen(bufStr)
|
||||
return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:]
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) {
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case Nemotron3NanoCollectingThinking:
|
||||
if strings.Contains(bufStr, nemotronThinkClose) {
|
||||
split := strings.SplitN(bufStr, nemotronThinkClose, 2)
|
||||
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.WriteString(remainder)
|
||||
// Transition to whitespace-skipping state if buffer is empty,
|
||||
// otherwise go directly to content collection
|
||||
if remainder == "" {
|
||||
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
if thinking != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
// We only want to skip whitespace between thinking and content
|
||||
case Nemotron3NanoSkipWhitespaceAfterThinking:
|
||||
bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr)
|
||||
if bufStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
return p.toolParser.Add(s, done)
|
||||
}
|
||||
return nil, true
|
||||
|
||||
// Nemotron3NanoCollectingThinking - buffer and look for end markers
|
||||
p.buffer.WriteString(s)
|
||||
bufStr := p.buffer.String()
|
||||
|
||||
// Look for end of thinking: </think> or <tool_call> (model may skip </think>)
|
||||
thinkIdx := strings.Index(bufStr, nemotronThinkClose)
|
||||
toolIdx := strings.Index(bufStr, nemotronToolCallOpen)
|
||||
|
||||
var endIdx int = -1
|
||||
var remainder string
|
||||
|
||||
if thinkIdx != -1 && (toolIdx == -1 || thinkIdx < toolIdx) {
|
||||
endIdx = thinkIdx
|
||||
remainder = strings.TrimLeftFunc(bufStr[thinkIdx+len(nemotronThinkClose):], unicode.IsSpace)
|
||||
} else if toolIdx != -1 {
|
||||
endIdx = toolIdx
|
||||
remainder = bufStr[toolIdx:] // Include <tool_call> tag
|
||||
}
|
||||
|
||||
if endIdx != -1 {
|
||||
thinking = strings.TrimRightFunc(bufStr[:endIdx], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
|
||||
if remainder == "" {
|
||||
p.state = Nemotron3NanoSkipWhitespaceAfterThinking
|
||||
} else {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
content, _, calls, err = p.toolParser.Add(remainder, done)
|
||||
case Nemotron3NanoCollectingContent:
|
||||
if strings.Contains(bufStr, nemotronToolCallOpen) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallOpen, 2)
|
||||
content := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(split[1])
|
||||
p.state = Nemotron3NanoCollectingToolCalls
|
||||
if content != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: content}}, true
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
return content, thinking, calls, err
|
||||
unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambig)
|
||||
if unambig != "" {
|
||||
return []nemotronEvent{nemotronEventContent{content: unambig}}, false
|
||||
}
|
||||
return nil, false
|
||||
|
||||
case Nemotron3NanoCollectingToolCalls:
|
||||
if strings.Contains(bufStr, nemotronToolCallClose) {
|
||||
split := strings.SplitN(bufStr, nemotronToolCallClose, 2)
|
||||
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
|
||||
var events []nemotronEvent
|
||||
if tc, err := p.parseToolCall(split[0]); err == nil {
|
||||
events = append(events, nemotronEventToolCall{toolCall: tc})
|
||||
}
|
||||
|
||||
if !strings.Contains(remaining, nemotronToolCallOpen) {
|
||||
p.state = Nemotron3NanoCollectingContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// No end marker - emit unambiguous thinking
|
||||
thinking = p.emitThinking(bufStr)
|
||||
return "", thinking, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// emitThinking returns unambiguous thinking content, keeping potential partial tags in buffer
|
||||
func (p *Nemotron3NanoParser) emitThinking(bufStr string) string {
|
||||
// Check for partial </think> or <tool_call> at end
|
||||
thinkOverlap := overlap(bufStr, nemotronThinkClose)
|
||||
toolOverlap := overlap(bufStr, nemotronToolCallOpen)
|
||||
maxOverlap := max(thinkOverlap, toolOverlap)
|
||||
var (
|
||||
nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`)
|
||||
nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`)
|
||||
)
|
||||
|
||||
if maxOverlap > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-maxOverlap]
|
||||
unambiguous = strings.TrimRightFunc(unambiguous, unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-maxOverlap:])
|
||||
return unambiguous
|
||||
func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
// Extract function name
|
||||
fnMatch := nemotronFunctionRegex.FindStringSubmatch(content)
|
||||
if len(fnMatch) < 2 {
|
||||
return toolCall, nil
|
||||
}
|
||||
toolCall.Function.Name = fnMatch[1]
|
||||
|
||||
// Extract parameters
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range paramMatches {
|
||||
if len(match) >= 3 {
|
||||
paramName := match[1]
|
||||
paramValue := strings.TrimSpace(match[2])
|
||||
|
||||
// Try to parse as typed value based on tool definition
|
||||
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||
}
|
||||
}
|
||||
|
||||
// No partial tags - emit all but trailing whitespace
|
||||
wsLen := trailingWhitespaceLen(bufStr)
|
||||
if wsLen > 0 {
|
||||
unambiguous := bufStr[:len(bufStr)-wsLen]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(bufStr[len(bufStr)-wsLen:])
|
||||
return unambiguous
|
||||
}
|
||||
|
||||
// Nothing to hold back
|
||||
p.buffer.Reset()
|
||||
return bufStr
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any {
|
||||
// Find the matching tool to get parameter type
|
||||
var paramType api.PropertyType
|
||||
for _, tool := range p.tools {
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parseValue(raw, paramType)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestNemotron3NanoParser tests Nemotron-specific behavior (thinking support).
|
||||
// Tool call parsing is tested in qwen3coder_test.go since Nemotron delegates to Qwen3CoderParser.
|
||||
func TestNemotron3NanoParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -19,6 +17,18 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "simple content - no thinking",
|
||||
input: "Hello, how can I help you?",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "simple content - thinking disabled",
|
||||
input: "Hello, how can I help you?",
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "thinking then content",
|
||||
input: "Let me think about this...</think>\nHere is my answer.",
|
||||
@@ -33,6 +43,69 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude",
|
||||
expectedContent: "The answer is 42.",
|
||||
},
|
||||
{
|
||||
name: "simple tool call",
|
||||
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "content then tool call",
|
||||
input: "Let me check the weather.\n<tool_call>\n<function=get_weather>\n<parameter=city>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple parameters",
|
||||
input: "<tool_call>\n<function=book_flight>\n<parameter=from>\nSFO\n</parameter>\n<parameter=to>\nNYC\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
input: "<tool_call>\n<function=get_weather>\n<parameter=city>\nSan Francisco\n</parameter>\n</function>\n</tool_call>\n" +
|
||||
"<tool_call>\n<function=get_weather>\n<parameter=city>\nNew York\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking then tool call",
|
||||
input: "I should check the weather...</think>\n<tool_call>\n<function=get_weather>\n<parameter=city>\nParis\n</parameter>\n</function>\n</tool_call>",
|
||||
@@ -62,6 +135,19 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiline parameter value",
|
||||
input: "<tool_call>\n<function=create_note>\n<parameter=content>\nLine 1\nLine 2\nLine 3\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty thinking block - immediate close",
|
||||
input: "</think>\nHere is my answer.",
|
||||
@@ -75,6 +161,18 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expectedContent: "</think>\nSome content after spurious tag.",
|
||||
},
|
||||
{
|
||||
name: "tool call with no function name - returns empty tool call",
|
||||
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
|
||||
},
|
||||
{
|
||||
name: "content with newlines preserved",
|
||||
input: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Line 1\n\nLine 2\n\n\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "thinking with only whitespace after close tag",
|
||||
input: "My thoughts...</think> \n\t\n Content here.",
|
||||
@@ -82,6 +180,25 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
expectedThinking: "My thoughts...",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "unicode content",
|
||||
input: "Hello 世界! 🌍 Ñoño",
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello 世界! 🌍 Ñoño",
|
||||
},
|
||||
{
|
||||
name: "tool call with numeric parameter",
|
||||
input: "<tool_call>\n<function=set_temp>\n<parameter=value>\n42\n</parameter>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: testArgs(map[string]any{"value": "42"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -116,8 +233,6 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNemotron3NanoParser_Streaming tests streaming behavior for thinking support.
|
||||
// Tool call streaming is tested in qwen3coder_test.go.
|
||||
func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -127,6 +242,18 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "streaming content character by character",
|
||||
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "streaming content small tokens",
|
||||
chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello, how can I help you today?",
|
||||
},
|
||||
{
|
||||
name: "streaming thinking then content - granular",
|
||||
chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."},
|
||||
@@ -141,6 +268,45 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "Step 1: Analyze\nStep 2: Process",
|
||||
expectedContent: "The answer.",
|
||||
},
|
||||
{
|
||||
name: "streaming tool call - highly granular",
|
||||
chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "</", "param", "eter", ">", "\n", "</", "func", "tion", ">", "\n", "</", "tool", "_", "call", ">"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming content then tool call - granular",
|
||||
chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "<function=", "get_weather", ">", "\n", "<parameter=", "city", ">", "\n", "NYC", "\n", "</parameter>", "\n", "</function>", "\n", "</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call tag split character by character",
|
||||
chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking close tag split character by character",
|
||||
chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"},
|
||||
@@ -155,6 +321,22 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "Thinking...",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple parameters - streaming",
|
||||
chunks: []string{"<tool_", "call>\n", "<function", "=book_", "flight>", "\n<para", "meter=", "from>\n", "SFO\n", "</param", "eter>", "\n<param", "eter=to", ">\nNYC", "\n</para", "meter>", "\n</func", "tion>\n", "</tool_", "call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking then content then tool call - streaming",
|
||||
chunks: []string{"Ana", "lyzing", " your", " request", "...", "</", "think", ">\n", "I'll", " check", " that", " for", " you", ".", "\n", "<tool", "_call", ">\n", "<function", "=search", ">\n", "<parameter", "=query", ">\n", "test", " query", "\n</", "parameter", ">\n", "</function", ">\n", "</tool", "_call", ">"},
|
||||
@@ -170,6 +352,45 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls - streaming",
|
||||
chunks: []string{
|
||||
"<tool_call>", "\n", "<function=", "get_weather>", "\n",
|
||||
"<parameter=", "city>\n", "San Fran", "cisco\n", "</parameter>", "\n",
|
||||
"</function>", "\n", "</tool_call>", "\n",
|
||||
"<tool_", "call>\n", "<function", "=get_weather", ">\n",
|
||||
"<param", "eter=city", ">\nNew", " York\n", "</parameter>\n",
|
||||
"</function>\n", "</tool_call>",
|
||||
},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiline parameter - streaming",
|
||||
chunks: []string{"<tool_call>\n", "<function=", "create_note>\n", "<parameter=", "content>\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n</parameter>\n", "</function>\n", "</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty thinking block",
|
||||
chunks: []string{"</think>", "\n", "Just content."},
|
||||
@@ -177,6 +398,12 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
expectedThinking: "",
|
||||
expectedContent: "Just content.",
|
||||
},
|
||||
{
|
||||
name: "empty input chunks interspersed",
|
||||
chunks: []string{"Hello", "", " ", "", "world", "", "!"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Hello world!",
|
||||
},
|
||||
{
|
||||
name: "tool call immediately after think close - no content",
|
||||
chunks: []string{"Analyzing...", "</think>", "\n", "<tool_call>", "\n<function=test>\n</function>\n", "</tool_call>"},
|
||||
@@ -191,6 +418,25 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with empty parameter value",
|
||||
chunks: []string{"<tool_call>\n<function=test>\n<parameter=name>\n", "\n</parameter>\n</function>\n</tool_call>"},
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: testArgs(map[string]any{"name": ""}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "partial tool call tag at end - buffered",
|
||||
chunks: []string{"Here's some content", "<tool"},
|
||||
thinkValue: nil,
|
||||
expectedContent: "Here's some content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -326,65 +572,3 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNemotron3NanoParser_ToolCallWithoutThinkClose tests the case where thinking is enabled
|
||||
// but the model outputs content + tool call WITHOUT the </think> tag.
|
||||
// The parser should still parse the tool call (content before is treated as thinking).
|
||||
func TestNemotron3NanoParser_ToolCallWithoutThinkClose(t *testing.T) {
|
||||
chunks := []string{
|
||||
"Let", " me", " analyze", " this", ".", "\n",
|
||||
"<tool_call>", "\n",
|
||||
"<function=get_weather>", "\n",
|
||||
"<parameter=city>", "Paris", "</parameter>", "\n",
|
||||
"</function>", "\n",
|
||||
"</tool_call>",
|
||||
}
|
||||
|
||||
p := &Nemotron3NanoParser{}
|
||||
p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking ENABLED but model doesn't output </think>
|
||||
|
||||
var allContent string
|
||||
var allThinking string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for _, chunk := range chunks {
|
||||
content, thinking, calls, err := p.Add(chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allThinking += thinking
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Drain
|
||||
content, thinking, calls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allThinking += thinking
|
||||
allCalls = append(allCalls, calls...)
|
||||
|
||||
// The parser was in thinking mode, so text before <tool_call> is emitted as thinking.
|
||||
expectedThinking := "Let me analyze this."
|
||||
|
||||
expectedCalls := []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if allContent != "" {
|
||||
t.Errorf("expected no content (text was streamed as thinking), got: %q", allContent)
|
||||
}
|
||||
if diff := cmp.Diff(allThinking, expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,37 +91,6 @@ func TestQwenParserStreaming(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tool call tags split character by character",
|
||||
steps: []step{
|
||||
{input: "<", wantEvents: []qwenEvent{}},
|
||||
{input: "t", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "_", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: ">", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "b", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "<", wantEvents: []qwenEvent{}},
|
||||
{input: "/", wantEvents: []qwenEvent{}},
|
||||
{input: "t", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "o", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "_", wantEvents: []qwenEvent{}},
|
||||
{input: "c", wantEvents: []qwenEvent{}},
|
||||
{input: "a", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: "l", wantEvents: []qwenEvent{}},
|
||||
{input: ">", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "abc"}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "trailing whitespace between content and tool call",
|
||||
steps: []step{
|
||||
|
||||
@@ -630,10 +630,6 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
|
||||
|
||||
// decodeImageURL decodes a base64 data URI into raw image bytes.
|
||||
func decodeImageURL(url string) (api.ImageData, error) {
|
||||
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("image URLs are not currently supported, please use base64 encoded data instead")
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png", "webp"}
|
||||
|
||||
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
|
||||
@@ -737,60 +733,3 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageURLOrData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageURLOrData contains either a URL or base64-encoded image data.
|
||||
type ImageURLOrData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageGenerationRequest converts an OpenAI image generation request to an Ollama GenerateRequest.
|
||||
func FromImageGenerationRequest(r ImageGenerationRequest) api.GenerateRequest {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ToImageGenerationResponse converts an Ollama GenerateResponse to an OpenAI ImageGenerationResponse.
|
||||
func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationResponse {
|
||||
var data []ImageURLOrData
|
||||
if resp.Image != "" {
|
||||
data = []ImageURLOrData{{B64JSON: resp.Image}}
|
||||
}
|
||||
return ImageGenerationResponse{
|
||||
Created: resp.CreatedAt.Unix(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -266,9 +265,9 @@ type ResponsesText struct {
|
||||
type ResponsesTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description"` // nullable but required
|
||||
Strict *bool `json:"strict"` // nullable but required
|
||||
Parameters map[string]any `json:"parameters"` // nullable but required
|
||||
Description string `json:"description,omitempty"`
|
||||
Strict bool `json:"strict,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesRequest struct {
|
||||
@@ -476,16 +475,11 @@ func convertTool(t ResponsesTool) (api.Tool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var description string
|
||||
if t.Description != nil {
|
||||
description = *t.Description
|
||||
}
|
||||
|
||||
return api.Tool{
|
||||
Type: t.Type,
|
||||
Function: api.ToolFunction{
|
||||
Name: t.Name,
|
||||
Description: description,
|
||||
Description: t.Description,
|
||||
Parameters: params,
|
||||
},
|
||||
}, nil
|
||||
@@ -522,60 +516,17 @@ func convertInputMessage(m ResponsesInputMessage) (api.Message, error) {
|
||||
|
||||
// Response types for the Responses API
|
||||
|
||||
// ResponsesTextField represents the text output configuration in the response.
|
||||
type ResponsesTextField struct {
|
||||
Format ResponsesTextFormat `json:"format"`
|
||||
}
|
||||
|
||||
// ResponsesReasoningOutput represents reasoning configuration in the response.
|
||||
type ResponsesReasoningOutput struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Summary *string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesError represents an error in the response.
|
||||
type ResponsesError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResponsesIncompleteDetails represents details about why a response was incomplete.
|
||||
type ResponsesIncompleteDetails struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt *int64 `json:"completed_at"`
|
||||
Status string `json:"status"`
|
||||
IncompleteDetails *ResponsesIncompleteDetails `json:"incomplete_details"`
|
||||
Model string `json:"model"`
|
||||
PreviousResponseID *string `json:"previous_response_id"`
|
||||
Instructions *string `json:"instructions"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Error *ResponsesError `json:"error"`
|
||||
Tools []ResponsesTool `json:"tools"`
|
||||
ToolChoice any `json:"tool_choice"`
|
||||
Truncation string `json:"truncation"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||
Text ResponsesTextField `json:"text"`
|
||||
TopP float64 `json:"top_p"`
|
||||
PresencePenalty float64 `json:"presence_penalty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||
TopLogprobs int `json:"top_logprobs"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Reasoning *ResponsesReasoningOutput `json:"reasoning"`
|
||||
Usage *ResponsesUsage `json:"usage"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens"`
|
||||
MaxToolCalls *int `json:"max_tool_calls"`
|
||||
Store bool `json:"store"`
|
||||
Background bool `json:"background"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
SafetyIdentifier *string `json:"safety_identifier"`
|
||||
PromptCacheKey *string `json:"prompt_cache_key"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
// TODO(drifkin): add `temperature` and `top_p` to the response, but this
|
||||
// requires additional plumbing to find the effective values since the
|
||||
// defaults can come from the model or the request
|
||||
}
|
||||
|
||||
type ResponsesOutputItem struct {
|
||||
@@ -599,39 +550,18 @@ type ResponsesReasoningSummary struct {
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
Annotations []any `json:"annotations"`
|
||||
Logprobs []any `json:"logprobs"`
|
||||
}
|
||||
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
Type string `json:"type"` // "output_text"
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type ResponsesUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokensDetails ResponsesInputTokensDetails `json:"input_tokens_details"`
|
||||
OutputTokensDetails ResponsesOutputTokensDetails `json:"output_tokens_details"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// derefFloat64 returns the value of a float64 pointer, or a default if nil.
|
||||
func derefFloat64(p *float64, def float64) float64 {
|
||||
if p != nil {
|
||||
return *p
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response.
|
||||
// The request is used to echo back request parameters in the response.
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse, request ResponsesRequest) ResponsesResponse {
|
||||
// ToResponse converts an api.ChatResponse to a Responses API response
|
||||
func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse) ResponsesResponse {
|
||||
var output []ResponsesOutputItem
|
||||
|
||||
// Add reasoning item if thinking is present
|
||||
@@ -655,7 +585,6 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
|
||||
output = append(output, ResponsesOutputItem{
|
||||
ID: fmt.Sprintf("fc_%s_%d", responseID, i),
|
||||
Type: "function_call",
|
||||
Status: "completed",
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
@@ -669,90 +598,25 @@ func ToResponse(model, responseID, itemID string, chatResponse api.ChatResponse,
|
||||
Role: "assistant",
|
||||
Content: []ResponsesOutputContent{
|
||||
{
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
Annotations: []any{},
|
||||
Logprobs: []any{},
|
||||
Type: "output_text",
|
||||
Text: chatResponse.Message.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
var instructions *string
|
||||
if request.Instructions != "" {
|
||||
instructions = &request.Instructions
|
||||
}
|
||||
|
||||
// Build truncation with default
|
||||
truncation := "disabled"
|
||||
if request.Truncation != nil {
|
||||
truncation = *request.Truncation
|
||||
}
|
||||
|
||||
tools := request.Tools
|
||||
if tools == nil {
|
||||
tools = []ResponsesTool{}
|
||||
}
|
||||
|
||||
text := ResponsesTextField{
|
||||
Format: ResponsesTextFormat{Type: "text"},
|
||||
}
|
||||
if request.Text != nil && request.Text.Format != nil {
|
||||
text.Format = *request.Text.Format
|
||||
}
|
||||
|
||||
// Build reasoning output from request
|
||||
var reasoning *ResponsesReasoningOutput
|
||||
if request.Reasoning.Effort != "" || request.Reasoning.Summary != "" {
|
||||
reasoning = &ResponsesReasoningOutput{}
|
||||
if request.Reasoning.Effort != "" {
|
||||
reasoning.Effort = &request.Reasoning.Effort
|
||||
}
|
||||
if request.Reasoning.Summary != "" {
|
||||
reasoning.Summary = &request.Reasoning.Summary
|
||||
}
|
||||
}
|
||||
|
||||
return ResponsesResponse{
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
CompletedAt: nil, // Set by middleware when writing final response
|
||||
Status: "completed",
|
||||
IncompleteDetails: nil, // Only populated if response incomplete
|
||||
Model: model,
|
||||
PreviousResponseID: nil, // Not supported
|
||||
Instructions: instructions,
|
||||
Output: output,
|
||||
Error: nil, // Only populated on failure
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // Default value
|
||||
Truncation: truncation,
|
||||
ParallelToolCalls: true, // Default value
|
||||
Text: text,
|
||||
TopP: derefFloat64(request.TopP, 1.0),
|
||||
PresencePenalty: 0, // Default value
|
||||
FrequencyPenalty: 0, // Default value
|
||||
TopLogprobs: 0, // Default value
|
||||
Temperature: derefFloat64(request.Temperature, 1.0),
|
||||
Reasoning: reasoning,
|
||||
ID: responseID,
|
||||
Object: "response",
|
||||
CreatedAt: chatResponse.CreatedAt.Unix(),
|
||||
Status: "completed",
|
||||
Model: model,
|
||||
Output: output,
|
||||
Usage: &ResponsesUsage{
|
||||
InputTokens: chatResponse.PromptEvalCount,
|
||||
OutputTokens: chatResponse.EvalCount,
|
||||
TotalTokens: chatResponse.PromptEvalCount + chatResponse.EvalCount,
|
||||
// TODO(drifkin): wire through the actual values
|
||||
InputTokensDetails: ResponsesInputTokensDetails{CachedTokens: 0},
|
||||
// TODO(drifkin): wire through the actual values
|
||||
OutputTokensDetails: ResponsesOutputTokensDetails{ReasoningTokens: 0},
|
||||
},
|
||||
MaxOutputTokens: request.MaxOutputTokens,
|
||||
MaxToolCalls: nil, // Not supported
|
||||
Store: false, // We don't store responses
|
||||
Background: request.Background,
|
||||
ServiceTier: "default", // Default value
|
||||
Metadata: map[string]any{},
|
||||
SafetyIdentifier: nil, // Not supported
|
||||
PromptCacheKey: nil, // Not supported
|
||||
}
|
||||
}
|
||||
|
||||
@@ -772,7 +636,6 @@ type ResponsesStreamConverter struct {
|
||||
responseID string
|
||||
itemID string
|
||||
model string
|
||||
request ResponsesRequest
|
||||
|
||||
// State tracking (mutated across Process calls)
|
||||
firstWrite bool
|
||||
@@ -805,12 +668,11 @@ func (c *ResponsesStreamConverter) newEvent(eventType string, data map[string]an
|
||||
}
|
||||
|
||||
// NewResponsesStreamConverter creates a new converter with the given configuration.
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string, request ResponsesRequest) *ResponsesStreamConverter {
|
||||
func NewResponsesStreamConverter(responseID, itemID, model string) *ResponsesStreamConverter {
|
||||
return &ResponsesStreamConverter{
|
||||
responseID: responseID,
|
||||
itemID: itemID,
|
||||
model: model,
|
||||
request: request,
|
||||
firstWrite: true,
|
||||
}
|
||||
}
|
||||
@@ -855,120 +717,25 @@ func (c *ResponsesStreamConverter) Process(r api.ChatResponse) []ResponsesStream
|
||||
return events
|
||||
}
|
||||
|
||||
// buildResponseObject creates a full response object with all required fields for streaming events.
|
||||
func (c *ResponsesStreamConverter) buildResponseObject(status string, output []any, usage map[string]any) map[string]any {
|
||||
var instructions any = nil
|
||||
if c.request.Instructions != "" {
|
||||
instructions = c.request.Instructions
|
||||
}
|
||||
|
||||
truncation := "disabled"
|
||||
if c.request.Truncation != nil {
|
||||
truncation = *c.request.Truncation
|
||||
}
|
||||
|
||||
var tools []any
|
||||
if c.request.Tools != nil {
|
||||
for _, t := range c.request.Tools {
|
||||
tools = append(tools, map[string]any{
|
||||
"type": t.Type,
|
||||
"name": t.Name,
|
||||
"description": t.Description,
|
||||
"strict": t.Strict,
|
||||
"parameters": t.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if tools == nil {
|
||||
tools = []any{}
|
||||
}
|
||||
|
||||
textFormat := map[string]any{"type": "text"}
|
||||
if c.request.Text != nil && c.request.Text.Format != nil {
|
||||
textFormat = map[string]any{
|
||||
"type": c.request.Text.Format.Type,
|
||||
}
|
||||
if c.request.Text.Format.Name != "" {
|
||||
textFormat["name"] = c.request.Text.Format.Name
|
||||
}
|
||||
if c.request.Text.Format.Schema != nil {
|
||||
textFormat["schema"] = c.request.Text.Format.Schema
|
||||
}
|
||||
if c.request.Text.Format.Strict != nil {
|
||||
textFormat["strict"] = *c.request.Text.Format.Strict
|
||||
}
|
||||
}
|
||||
|
||||
var reasoning any = nil
|
||||
if c.request.Reasoning.Effort != "" || c.request.Reasoning.Summary != "" {
|
||||
r := map[string]any{}
|
||||
if c.request.Reasoning.Effort != "" {
|
||||
r["effort"] = c.request.Reasoning.Effort
|
||||
} else {
|
||||
r["effort"] = nil
|
||||
}
|
||||
if c.request.Reasoning.Summary != "" {
|
||||
r["summary"] = c.request.Reasoning.Summary
|
||||
} else {
|
||||
r["summary"] = nil
|
||||
}
|
||||
reasoning = r
|
||||
}
|
||||
|
||||
// Build top_p and temperature with defaults
|
||||
topP := 1.0
|
||||
if c.request.TopP != nil {
|
||||
topP = *c.request.TopP
|
||||
}
|
||||
temperature := 1.0
|
||||
if c.request.Temperature != nil {
|
||||
temperature = *c.request.Temperature
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"completed_at": nil,
|
||||
"status": status,
|
||||
"incomplete_details": nil,
|
||||
"model": c.model,
|
||||
"previous_response_id": nil,
|
||||
"instructions": instructions,
|
||||
"output": output,
|
||||
"error": nil,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"truncation": truncation,
|
||||
"parallel_tool_calls": true,
|
||||
"text": map[string]any{"format": textFormat},
|
||||
"top_p": topP,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"top_logprobs": 0,
|
||||
"temperature": temperature,
|
||||
"reasoning": reasoning,
|
||||
"usage": usage,
|
||||
"max_output_tokens": c.request.MaxOutputTokens,
|
||||
"max_tool_calls": nil,
|
||||
"store": false,
|
||||
"background": c.request.Background,
|
||||
"service_tier": "default",
|
||||
"metadata": map[string]any{},
|
||||
"safety_identifier": nil,
|
||||
"prompt_cache_key": nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseCreatedEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.created", map[string]any{
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ResponsesStreamConverter) createResponseInProgressEvent() ResponsesStreamEvent {
|
||||
return c.newEvent("response.in_progress", map[string]any{
|
||||
"response": c.buildResponseObject("in_progress", []any{}, nil),
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "in_progress",
|
||||
"output": []any{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -995,10 +762,9 @@ func (c *ResponsesStreamConverter) processThinking(thinking string) []ResponsesS
|
||||
|
||||
// Emit delta
|
||||
events = append(events, c.newEvent("response.reasoning_summary_text.delta", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"delta": thinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"delta": thinking,
|
||||
}))
|
||||
|
||||
// TODO(drifkin): consider adding
|
||||
@@ -1017,10 +783,9 @@ func (c *ResponsesStreamConverter) finishReasoning() []ResponsesStreamEvent {
|
||||
|
||||
events := []ResponsesStreamEvent{
|
||||
c.newEvent("response.reasoning_summary_text.done", map[string]any{
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"summary_index": 0,
|
||||
"text": c.accumulatedThinking,
|
||||
"item_id": c.reasoningItemID,
|
||||
"output_index": c.outputIndex,
|
||||
"text": c.accumulatedThinking,
|
||||
}),
|
||||
c.newEvent("response.output_item.done", map[string]any{
|
||||
"output_index": c.outputIndex,
|
||||
@@ -1133,10 +898,8 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": c.contentIndex,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": "",
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -1150,7 +913,6 @@ func (c *ResponsesStreamConverter) processTextContent(content string) []Response
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"delta": content,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
return events
|
||||
@@ -1182,10 +944,8 @@ func (c *ResponsesStreamConverter) buildFinalOutput() []any {
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
}},
|
||||
})
|
||||
}
|
||||
@@ -1207,7 +967,6 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"text": c.accumulatedText,
|
||||
"logprobs": []any{},
|
||||
}))
|
||||
|
||||
// response.content_part.done
|
||||
@@ -1216,10 +975,8 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"output_index": c.outputIndex,
|
||||
"content_index": 0,
|
||||
"part": map[string]any{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -1232,31 +989,26 @@ func (c *ResponsesStreamConverter) processCompletion(r api.ChatResponse) []Respo
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{{
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
"annotations": []any{},
|
||||
"logprobs": []any{},
|
||||
"type": "output_text",
|
||||
"text": c.accumulatedText,
|
||||
}},
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// response.completed
|
||||
usage := map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
"output_tokens_details": map[string]any{
|
||||
"reasoning_tokens": 0,
|
||||
},
|
||||
}
|
||||
response := c.buildResponseObject("completed", c.buildFinalOutput(), usage)
|
||||
response["completed_at"] = time.Now().Unix()
|
||||
events = append(events, c.newEvent("response.completed", map[string]any{
|
||||
"response": response,
|
||||
"response": map[string]any{
|
||||
"id": c.responseID,
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": c.buildFinalOutput(),
|
||||
"usage": map[string]any{
|
||||
"input_tokens": r.PromptEvalCount,
|
||||
"output_tokens": r.EvalCount,
|
||||
"total_tokens": r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
return events
|
||||
|
||||
@@ -850,7 +850,7 @@ func TestFromResponsesRequest_Images(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk with content
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -916,7 +916,7 @@ func TestResponsesStreamConverter_TextOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
@@ -952,7 +952,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResponsesStreamConverter_Reasoning(t *testing.T) {
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk with thinking
|
||||
events := converter.Process(api.ChatResponse{
|
||||
@@ -1267,7 +1267,7 @@ func TestToResponse_WithReasoning(t *testing.T) {
|
||||
Content: "The answer is 42",
|
||||
},
|
||||
Done: true,
|
||||
}, ResponsesRequest{})
|
||||
})
|
||||
|
||||
// Should have 2 output items: reasoning + message
|
||||
if len(response.Output) != 2 {
|
||||
@@ -1638,7 +1638,7 @@ func TestFromResponsesRequest_ShorthandFormats(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
// Verify that response.output_item.done includes content field for messages
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// First chunk
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1686,7 +1686,7 @@ func TestResponsesStreamConverter_OutputIncludesContent(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.completed includes the output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
// Process some content
|
||||
converter.Process(api.ChatResponse{
|
||||
@@ -1730,7 +1730,7 @@ func TestResponsesStreamConverter_ResponseCompletedIncludesOutput(t *testing.T)
|
||||
|
||||
func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
// Verify that response.created includes an empty output array
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hi"},
|
||||
@@ -1757,7 +1757,7 @@ func TestResponsesStreamConverter_ResponseCreatedIncludesOutput(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
// Verify that events include incrementing sequence numbers
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{Content: "Hello"},
|
||||
@@ -1791,7 +1791,7 @@ func TestResponsesStreamConverter_SequenceNumbers(t *testing.T) {
|
||||
|
||||
func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||
// Verify that function call items include status field
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b", ResponsesRequest{})
|
||||
converter := NewResponsesStreamConverter("resp_123", "msg_456", "gpt-oss:20b")
|
||||
|
||||
events := converter.Process(api.ChatResponse{
|
||||
Message: api.Message{
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -37,11 +36,10 @@ type Terminal struct {
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
pastedLines []string
|
||||
Prompt *Prompt
|
||||
Terminal *Terminal
|
||||
History *History
|
||||
Pasting bool
|
||||
}
|
||||
|
||||
func New(prompt Prompt) (*Instance, error) {
|
||||
@@ -176,8 +174,6 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharEsc:
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
i.pastedLines = nil
|
||||
i.Prompt.UseAlt = false
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
@@ -192,23 +188,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharForward:
|
||||
buf.MoveRight()
|
||||
case CharBackspace, CharCtrlH:
|
||||
if buf.IsEmpty() && len(i.pastedLines) > 0 {
|
||||
lastIdx := len(i.pastedLines) - 1
|
||||
prevLine := i.pastedLines[lastIdx]
|
||||
i.pastedLines = i.pastedLines[:lastIdx]
|
||||
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + ClearToEOL)
|
||||
if len(i.pastedLines) == 0 {
|
||||
fmt.Print(i.Prompt.Prompt)
|
||||
i.Prompt.UseAlt = false
|
||||
} else {
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
}
|
||||
for _, r := range prevLine {
|
||||
buf.Add(r)
|
||||
}
|
||||
} else {
|
||||
buf.Remove()
|
||||
}
|
||||
buf.Remove()
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
@@ -231,28 +211,13 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharCtrlJ:
|
||||
i.pastedLines = append(i.pastedLines, buf.String())
|
||||
buf.Buf.Clear()
|
||||
buf.Pos = 0
|
||||
buf.DisplayPos = 0
|
||||
buf.LineHasSpace.Clear()
|
||||
fmt.Println()
|
||||
fmt.Print(i.Prompt.AltPrompt)
|
||||
i.Prompt.UseAlt = true
|
||||
continue
|
||||
case CharEnter:
|
||||
case CharEnter, CharCtrlJ:
|
||||
output := buf.String()
|
||||
if len(i.pastedLines) > 0 {
|
||||
output = strings.Join(i.pastedLines, "\n") + "\n" + output
|
||||
i.pastedLines = nil
|
||||
}
|
||||
if output != "" {
|
||||
i.History.Add(output)
|
||||
}
|
||||
buf.MoveToEnd()
|
||||
fmt.Println()
|
||||
i.Prompt.UseAlt = false
|
||||
|
||||
return output, nil
|
||||
default:
|
||||
|
||||
@@ -7,13 +7,26 @@ import (
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
if len(args) > 0 {
|
||||
switch args[0] {
|
||||
case "--ollama-engine":
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--image-engine":
|
||||
return imagerunner.Execute(args[1:])
|
||||
}
|
||||
if args[0] == "runner" {
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
var newRunner bool
|
||||
var imageRunner bool
|
||||
if len(args) > 0 && args[0] == "--ollama-engine" {
|
||||
args = args[1:]
|
||||
newRunner = true
|
||||
}
|
||||
if len(args) > 0 && args[0] == "--image-engine" {
|
||||
args = args[1:]
|
||||
imageRunner = true
|
||||
}
|
||||
|
||||
if imageRunner {
|
||||
return imagerunner.Execute(args)
|
||||
} else if newRunner {
|
||||
return ollamarunner.Execute(args)
|
||||
} else {
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Override CGO flags to point to the amd64 build directory
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0"
|
||||
else
|
||||
BUILD_DIR=build
|
||||
cmake --preset MLX \
|
||||
@@ -71,12 +71,10 @@ _build_darwin() {
|
||||
cmake --install $BUILD_DIR --component MLX
|
||||
# Use default CGO flags from mlx.go for arm64
|
||||
MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0"
|
||||
fi
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX .
|
||||
# Copy MLX libraries to same directory as executable for dlopen
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/
|
||||
cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx .
|
||||
GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX .
|
||||
done
|
||||
}
|
||||
|
||||
@@ -84,10 +82,12 @@ _sign_darwin() {
|
||||
status "Creating universal binary..."
|
||||
mkdir -p dist/darwin
|
||||
lipo -create -output dist/darwin/ollama dist/darwin-*/ollama
|
||||
lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx
|
||||
chmod +x dist/darwin/ollama
|
||||
chmod +x dist/darwin/ollama-mlx
|
||||
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do
|
||||
for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F
|
||||
done
|
||||
|
||||
@@ -154,6 +154,7 @@ _build_macapp() {
|
||||
mkdir -p dist/Ollama.app/Contents/Resources
|
||||
if [ -d dist/darwin-amd64 ]; then
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama
|
||||
lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx
|
||||
for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do
|
||||
lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F)
|
||||
done
|
||||
@@ -165,27 +166,28 @@ _build_macapp() {
|
||||
cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama
|
||||
cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/
|
||||
fi
|
||||
cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx
|
||||
chmod a+x dist/Ollama.app/Contents/Resources/ollama
|
||||
|
||||
# Sign
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do
|
||||
for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib}
|
||||
done
|
||||
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app
|
||||
fi
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
if [ -n "$APPLE_IDENTITY" ]; then
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
|
||||
rm -f dist/Ollama.dmg
|
||||
|
||||
|
||||
@@ -50,17 +50,12 @@ func (r registryChallenge) URL() (*url.URL, error) {
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
|
||||
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
||||
redirectURL, err := challenge.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
|
||||
if redirectURL.Host != originalHost {
|
||||
return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
|
||||
}
|
||||
|
||||
sha256sum := sha256.Sum256(nil)
|
||||
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
realm string
|
||||
originalHost string
|
||||
wantMismatch bool
|
||||
}{
|
||||
{"https://example.com/token", "example.com", false},
|
||||
{"https://example.com/token", "other.com", true},
|
||||
{"https://example.com/token", "localhost:8000", true},
|
||||
{"https://localhost:5000/token", "localhost:5000", false},
|
||||
{"https://localhost:5000/token", "localhost:6000", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.originalHost, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
challenge := registryChallenge{Realm: tt.realm, Service: "test", Scope: "repo:x:pull"}
|
||||
_, err := getAuthorizationToken(ctx, challenge, tt.originalHost)
|
||||
|
||||
isMismatch := err != nil && strings.Contains(err.Error(), "does not match")
|
||||
if tt.wantMismatch && !isMismatch {
|
||||
t.Errorf("expected domain mismatch error, got: %v", err)
|
||||
}
|
||||
if !tt.wantMismatch && isMismatch {
|
||||
t.Errorf("unexpected domain mismatch error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistryChallenge(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantRealm, wantService, wantScope string
|
||||
}{
|
||||
{
|
||||
`Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
|
||||
"https://auth.example.com/token", "registry", "repo:foo:pull",
|
||||
},
|
||||
{
|
||||
`Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
|
||||
"https://r.ollama.ai/v2/token", "ollama", "-",
|
||||
},
|
||||
{"", "", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := parseRegistryChallenge(tt.input)
|
||||
if result.Realm != tt.wantRealm || result.Service != tt.wantService || result.Scope != tt.wantScope {
|
||||
t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
|
||||
tt.input, result.Realm, result.Service, result.Scope,
|
||||
tt.wantRealm, tt.wantService, tt.wantScope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURL(t *testing.T) {
|
||||
challenge := registryChallenge{
|
||||
Realm: "https://auth.example.com/token",
|
||||
Service: "registry",
|
||||
Scope: "repo:foo:pull repo:bar:push",
|
||||
}
|
||||
|
||||
u, err := challenge.URL()
|
||||
if err != nil {
|
||||
t.Fatalf("URL() error: %v", err)
|
||||
}
|
||||
|
||||
if u.Host != "auth.example.com" {
|
||||
t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
|
||||
}
|
||||
if u.Path != "/token" {
|
||||
t.Errorf("path = %q, want %q", u.Path, "/token")
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
if q.Get("service") != "registry" {
|
||||
t.Errorf("service = %q, want %q", q.Get("service"), "registry")
|
||||
}
|
||||
if scopes := q["scope"]; len(scopes) != 2 {
|
||||
t.Errorf("scope count = %d, want 2", len(scopes))
|
||||
}
|
||||
if q.Get("ts") == "" {
|
||||
t.Error("missing ts")
|
||||
}
|
||||
if q.Get("nonce") == "" {
|
||||
t.Error("missing nonce")
|
||||
}
|
||||
|
||||
// Nonces should differ between calls
|
||||
u2, _ := challenge.URL()
|
||||
if q.Get("nonce") == u2.Query().Get("nonce") {
|
||||
t.Error("nonce should be unique per call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryChallengeURLInvalid(t *testing.T) {
|
||||
challenge := registryChallenge{Realm: "://invalid"}
|
||||
if _, err := challenge.URL(); err == nil {
|
||||
t.Error("expected error for invalid URL")
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,6 @@ var (
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errCapabilityImage = errors.New("image generation")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
@@ -77,7 +76,7 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
|
||||
// Check for image generation model via config capabilities
|
||||
if slices.Contains(m.Config.Capabilities, "image") {
|
||||
return []model.Capability{model.CapabilityImage}
|
||||
return []model.Capability{model.CapabilityImageGeneration}
|
||||
}
|
||||
|
||||
// Check for completion capability
|
||||
@@ -160,7 +159,6 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
model.CapabilityImage: errCapabilityImage,
|
||||
}
|
||||
|
||||
for _, cap := range want {
|
||||
@@ -777,7 +775,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
}, base.Host)
|
||||
})
|
||||
}
|
||||
|
||||
if err := transfer.Download(ctx, transfer.DownloadOptions{
|
||||
@@ -852,7 +850,7 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
|
||||
Realm: challenge.Realm,
|
||||
Service: challenge.Service,
|
||||
Scope: challenge.Scope,
|
||||
}, base.Host)
|
||||
})
|
||||
}
|
||||
|
||||
return transfer.Upload(ctx, transfer.UploadOptions{
|
||||
@@ -918,7 +916,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
|
||||
// Handle authentication error with one retry
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestModelCapabilities(t *testing.T) {
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityImage},
|
||||
expectedCaps: []model.Capability{model.CapabilityImageGeneration},
|
||||
},
|
||||
{
|
||||
name: "model with completion capability",
|
||||
@@ -242,24 +242,6 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
expectedErrMsg: "unknown capability",
|
||||
},
|
||||
{
|
||||
name: "model missing image generation capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
expectedErrMsg: "does not support image generation",
|
||||
},
|
||||
{
|
||||
name: "model with image generation capability",
|
||||
model: Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityImage},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
170
server/routes.go
@@ -51,7 +51,7 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -164,6 +164,29 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
// ScheduleImageGenRunner schedules an image generation model runner.
|
||||
// This implements the imagegenapi.RunnerScheduler interface.
|
||||
func (s *Server) ScheduleImageGenRunner(c *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error) {
|
||||
m := &Model{
|
||||
Name: modelName,
|
||||
ShortName: modelName,
|
||||
ModelPath: modelName, // For image gen, ModelPath is just the model name
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(c.Request.Context(), m, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.llama, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -191,6 +214,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a known image generation model
|
||||
if imagegen.ResolveModelName(req.Model) != "" {
|
||||
imagegenapi.HandleGenerateRequest(c, s, req.Model, req.Prompt, req.KeepAlive, streamResponse)
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
@@ -220,12 +249,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle image generation models
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
s.handleImageGenerate(c, req, name.String(), checkpointStart)
|
||||
return
|
||||
}
|
||||
|
||||
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
|
||||
return
|
||||
@@ -1102,7 +1125,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
// For image generation models, populate details from imagegen package
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
if info, err := imagegen.GetModelInfo(name.String()); err == nil {
|
||||
modelDetails.Family = info.Architecture
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
|
||||
@@ -1110,22 +1133,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
|
||||
if req.System != "" {
|
||||
m.System = req.System
|
||||
}
|
||||
@@ -1208,27 +1215,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
if slices.Contains(m.Capabilities(), model.CapabilityImageGeneration) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -1600,12 +1587,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
// Experimental image generation support
|
||||
imagegenapi.RegisterRoutes(r, s)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
@@ -2472,91 +2460,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
// handleImageGenerate handles image generation requests within GenerateHandler.
|
||||
// This is called when the model has the Image capability.
|
||||
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
|
||||
// Validate image dimensions
|
||||
const maxDimension int32 = 4096
|
||||
if req.Width > maxDimension || req.Height > maxDimension {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule the runner for image generation
|
||||
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// Handle load-only request (empty prompt)
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers for streaming response
|
||||
c.Header("Content-Type", "application/x-ndjson")
|
||||
|
||||
// Get seed from options if provided
|
||||
var seed int64
|
||||
if s, ok := req.Options["seed"]; ok {
|
||||
switch v := s.(type) {
|
||||
case int:
|
||||
seed = int64(v)
|
||||
case int64:
|
||||
seed = v
|
||||
case float64:
|
||||
seed = int64(v)
|
||||
}
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: cr.Done,
|
||||
}
|
||||
|
||||
if cr.TotalSteps > 0 {
|
||||
res.Completed = int64(cr.Step)
|
||||
res.Total = int64(cr.TotalSteps)
|
||||
}
|
||||
|
||||
if cr.Image != "" {
|
||||
res.Image = cr.Image
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.Metrics.TotalDuration = time.Since(checkpointStart)
|
||||
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(res)
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
}); err != nil {
|
||||
// Only send JSON error if streaming hasn't started yet
|
||||
// (once streaming starts, headers are committed and we can't change status code)
|
||||
if !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,8 +574,7 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
refCount: 1,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -805,8 +807,32 @@ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return n
|
||||
func (s *mockLlm) HasExited() bool { return false }
|
||||
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
|
||||
|
||||
// TestImageGenCapabilityDetection verifies that models with "image" capability
|
||||
// are correctly identified and routed differently from language models.
|
||||
func TestImageGenCapabilityDetection(t *testing.T) {
|
||||
// Model with image capability should be detected
|
||||
imageModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"image"},
|
||||
},
|
||||
}
|
||||
require.True(t, slices.Contains(imageModel.Config.Capabilities, "image"))
|
||||
|
||||
// Model without image capability should not be detected
|
||||
langModel := &Model{
|
||||
Config: model.ConfigV2{
|
||||
Capabilities: []string{"completion"},
|
||||
},
|
||||
}
|
||||
require.False(t, slices.Contains(langModel.Config.Capabilities, "image"))
|
||||
|
||||
// Empty capabilities should not match
|
||||
emptyModel := &Model{}
|
||||
require.False(t, slices.Contains(emptyModel.Config.Capabilities, "image"))
|
||||
}
|
||||
|
||||
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
|
||||
// loaded in the scheduler can be evicted when idle.
|
||||
// loaded in the scheduler can be evicted by a language model request.
|
||||
func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
@@ -838,59 +864,3 @@ func TestImageGenRunnerCanBeEvicted(t *testing.T) {
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, "/fake/image/model", runner.modelPath)
|
||||
}
|
||||
|
||||
// TestImageGenSchedulerCoexistence verifies that image generation models
|
||||
// can coexist with language models in the scheduler and VRAM is tracked correctly.
|
||||
func TestImageGenSchedulerCoexistence(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
// Load both an imagegen runner and a language model runner
|
||||
imageGenRunner := &runnerRef{
|
||||
model: &Model{Name: "flux", ModelPath: "/fake/flux/model"},
|
||||
modelPath: "/fake/flux/model",
|
||||
llama: &mockLlm{vramSize: 8 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 8 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
langModelRunner := &runnerRef{
|
||||
model: &Model{Name: "llama3", ModelPath: "/fake/llama3/model"},
|
||||
modelPath: "/fake/llama3/model",
|
||||
llama: &mockLlm{vramSize: 4 * format.GigaByte, vramByGPU: map[ml.DeviceID]uint64{{Library: "Metal"}: 4 * format.GigaByte}},
|
||||
sessionDuration: 10 * time.Millisecond,
|
||||
numParallel: 1,
|
||||
refCount: 0,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["/fake/flux/model"] = imageGenRunner
|
||||
s.loaded["/fake/llama3/model"] = langModelRunner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify both are loaded
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
require.NotNil(t, s.loaded["/fake/flux/model"])
|
||||
require.NotNil(t, s.loaded["/fake/llama3/model"])
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Verify updateFreeSpace accounts for both
|
||||
gpus := []ml.DeviceInfo{
|
||||
{
|
||||
DeviceID: ml.DeviceID{Library: "Metal"},
|
||||
TotalMemory: 24 * format.GigaByte,
|
||||
FreeMemory: 24 * format.GigaByte,
|
||||
},
|
||||
}
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
// Free memory should be reduced by both models
|
||||
expectedFree := uint64(24*format.GigaByte) - uint64(8*format.GigaByte) - uint64(4*format.GigaByte)
|
||||
require.Equal(t, expectedFree, gpus[0].FreeMemory)
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
|
||||
token, err := getAuthorizationToken(ctx, challenge)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ const (
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
CapabilityImage = Capability("image")
|
||||
CapabilityImageGeneration = Capability("image")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
67
x/cmd/run.go
@@ -25,6 +25,14 @@ import (
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// MultilineState tracks the state of multiline input
|
||||
type MultilineState int
|
||||
|
||||
const (
|
||||
MultilineNone MultilineState = iota
|
||||
MultilineSystem
|
||||
)
|
||||
|
||||
// Tool output capping constants
|
||||
const (
|
||||
// localModelTokenLimit is the token limit for local models (smaller context).
|
||||
@@ -648,7 +656,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: "Press Enter to send",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -699,6 +707,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
var sb strings.Builder
|
||||
var format string
|
||||
var system string
|
||||
var multiline MultilineState = MultilineNone
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -712,12 +721,37 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
}
|
||||
scanner.Prompt.UseAlt = false
|
||||
sb.Reset()
|
||||
multiline = MultilineNone
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case multiline != MultilineNone:
|
||||
// check if there's a multiline terminating string
|
||||
before, ok := strings.CutSuffix(line, `"""`)
|
||||
sb.WriteString(before)
|
||||
if !ok {
|
||||
fmt.Fprintln(&sb)
|
||||
continue
|
||||
}
|
||||
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
multiline = MultilineNone
|
||||
scanner.Prompt.UseAlt = false
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
@@ -826,18 +860,41 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
options[args[2]] = fp[args[2]]
|
||||
case "system":
|
||||
if len(args) < 3 {
|
||||
fmt.Println("Usage: /set system <message>")
|
||||
fmt.Println("Usage: /set system <message> or /set system \"\"\"<multi-line message>\"\"\"")
|
||||
continue
|
||||
}
|
||||
|
||||
system = strings.Join(args[2:], " ")
|
||||
newMessage := api.Message{Role: "system", Content: system}
|
||||
multiline = MultilineSystem
|
||||
|
||||
line := strings.Join(args[2:], " ")
|
||||
line, ok := strings.CutPrefix(line, `"""`)
|
||||
if !ok {
|
||||
multiline = MultilineNone
|
||||
} else {
|
||||
// only cut suffix if the line is multiline
|
||||
line, ok = strings.CutSuffix(line, `"""`)
|
||||
if ok {
|
||||
multiline = MultilineNone
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(line)
|
||||
if multiline != MultilineNone {
|
||||
scanner.Prompt.UseAlt = true
|
||||
continue
|
||||
}
|
||||
|
||||
system = sb.String()
|
||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||
// Check if the slice is not empty and the last message is from 'system'
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == "system" {
|
||||
// Replace the last message
|
||||
messages[len(messages)-1] = newMessage
|
||||
} else {
|
||||
messages = append(messages, newMessage)
|
||||
}
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
continue
|
||||
default:
|
||||
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
|
||||
@@ -1024,7 +1081,7 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
if sb.Len() > 0 && multiline == MultilineNone {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
|
||||
@@ -1,282 +0,0 @@
|
||||
// Package client provides client-side model creation for safetensors-based models.
|
||||
//
|
||||
// This package is in x/ because the safetensors model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/create, so x/create
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
|
||||
func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
capabilities = []string{"image"}
|
||||
}
|
||||
|
||||
// Set up progress spinner
|
||||
statusMsg := "importing " + modelType
|
||||
spinner := progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
statusMsg = msg
|
||||
spinner = progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var err error
|
||||
if isSafetensors {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
|
||||
// When quantize is non-empty, returns multiple layers (weight + scales + optional qbias).
|
||||
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
return func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if quantize != "" {
|
||||
return createQuantizedLayers(r, name, dtype, shape, quantize)
|
||||
}
|
||||
return createUnquantizedLayer(r, name)
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, 0, len(layers))
|
||||
for _, l := range layers {
|
||||
serverLayers = append(serverLayers, server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add Modelfile layers if present
|
||||
if opts.Modelfile != nil {
|
||||
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelfileConfig(t *testing.T) {
|
||||
// Test that ModelfileConfig struct works as expected
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
|
||||
}
|
||||
if config.System != "You are a helpful assistant." {
|
||||
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
|
||||
}
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
config := &ModelfileConfig{}
|
||||
|
||||
if config.Template != "" {
|
||||
t.Errorf("Template should be empty, got %q", config.Template)
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Errorf("System should be empty, got %q", config.System)
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
// Test config with only some fields set
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
// System and License intentionally empty
|
||||
}
|
||||
|
||||
if config.Template == "" {
|
||||
t.Error("Template should not be empty")
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Error("System should be empty")
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
// Verify the minimum version constant is set
|
||||
if MinOllamaVersion == "" {
|
||||
t.Error("MinOllamaVersion should not be empty")
|
||||
}
|
||||
if MinOllamaVersion != "0.14.0" {
|
||||
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_InvalidDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for invalid directory
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: "/nonexistent/path",
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for directory without safetensors
|
||||
dir := t.TempDir()
|
||||
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: dir,
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
},
|
||||
}
|
||||
|
||||
if opts.ModelName != "my-model" {
|
||||
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
|
||||
}
|
||||
if opts.ModelDir != "/path/to/model" {
|
||||
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
|
||||
}
|
||||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "test",
|
||||
ModelDir: "/tmp",
|
||||
}
|
||||
|
||||
// Quantize should default to empty
|
||||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
t.Error("Modelfile should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
supported := QuantizeSupported()
|
||||
|
||||
// In non-mlx builds, this should be false
|
||||
// We can't easily test both cases, so just verify it returns something
|
||||
_ = supported
|
||||
}
|
||||
@@ -1,399 +0,0 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// ModelConfig represents the config blob stored with a model.
|
||||
type ModelConfig struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
// Manifest represents the manifest JSON structure.
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config ManifestLayer `json:"config"`
|
||||
Layers []ManifestLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
type ManifestLayer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// defaultManifestDir returns the manifest storage directory.
|
||||
func defaultManifestDir() string {
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// defaultBlobDir returns the blob storage directory.
|
||||
func defaultBlobDir() string {
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// loadManifest loads a manifest for the given model name.
|
||||
func loadManifest(modelName string) (*Manifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// loadModelConfig loads the config blob for a model.
|
||||
func loadModelConfig(modelName string) (*ModelConfig, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the config blob
|
||||
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config ModelConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModel checks if a model was created with the experimental
|
||||
// safetensors builder by checking the model format in the config.
|
||||
func IsSafetensorsModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
|
||||
// (has completion capability, not image generation).
|
||||
func IsSafetensorsLLMModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
|
||||
}
|
||||
|
||||
// IsImageGenModel checks if a model is an image generation model
|
||||
// (has image capability).
|
||||
func IsImageGenModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
|
||||
}
|
||||
|
||||
// GetModelArchitecture returns the architecture from the model's config.json layer.
|
||||
func GetModelArchitecture(modelName string) (string, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find the config.json layer
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
blobName := strings.Replace(layer.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prefer model_type, fall back to first architecture
|
||||
if cfg.ModelType != "" {
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
if len(cfg.Architectures) > 0 {
|
||||
return cfg.Architectures[0], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("architecture not found in model config")
|
||||
}
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
|
||||
// by looking for config.json and at least one .safetensors file.
|
||||
func IsSafetensorsModelDir(dir string) bool {
|
||||
// Must have config.json
|
||||
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must have at least one .safetensors file
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
|
||||
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
// Image gen specific: skip VAE entirely
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip layer norms and RMS norms
|
||||
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip biases
|
||||
if strings.HasSuffix(name, ".bias") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize weights
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size (32)
|
||||
if shape[len(shape)-1]%32 != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
// Process all safetensors files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
// Process all JSON config files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the index file as we don't need it after extraction
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use config.json as the config layer
|
||||
if cfgPath == "config.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
@@ -1,752 +0,0 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTensorModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid diffusers model with model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "directory with other files but no model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsTensorModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid safetensors model with config.json and .safetensors file",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "config.json only, no safetensors files",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "safetensors file only, no config.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple safetensors files with config.json",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0o644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsSafetensorsModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
|
||||
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
|
||||
if got != false {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
|
||||
func createMinimalSafetensors(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
|
||||
// Create a minimal safetensors file with a single float32 tensor
|
||||
header := map[string]interface{}{
|
||||
"test_tensor": map[string]interface{}{
|
||||
"dtype": "F32",
|
||||
"shape": []int{2, 2},
|
||||
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
|
||||
},
|
||||
}
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
// Pad header to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
// Write file
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Write header size (8 bytes, little endian)
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
|
||||
// Write header
|
||||
if _, err := f.Write(headerJSON); err != nil {
|
||||
t.Fatalf("failed to write header: %v", err)
|
||||
}
|
||||
|
||||
// Write tensor data (16 bytes of zeros for 4 float32 values)
|
||||
if _, err := f.Write(make([]byte, 16)); err != nil {
|
||||
t.Fatalf("failed to write tensor data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Track what was created
|
||||
var createdLayers []LayerInfo
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var manifestConfigLayer LayerInfo
|
||||
var manifestLayers []LayerInfo
|
||||
var statusMessages []string
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:test",
|
||||
Size: int64(len(data)),
|
||||
MediaType: mediaType,
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:tensor_" + name,
|
||||
Size: int64(len(data)),
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return []LayerInfo{layer}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
manifestConfigLayer = config
|
||||
manifestLayers = layers
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
// Run CreateSafetensorsModel
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify manifest was written
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-model" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
|
||||
}
|
||||
|
||||
// Verify config layer was set
|
||||
if manifestConfigLayer.Name != "config.json" {
|
||||
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
|
||||
}
|
||||
|
||||
// Verify we have at least one tensor and one config layer
|
||||
hasTensor := false
|
||||
hasConfig := false
|
||||
for _, layer := range manifestLayers {
|
||||
if layer.Name == "test_tensor" {
|
||||
hasTensor = true
|
||||
}
|
||||
if layer.Name == "config.json" {
|
||||
hasConfig = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTensor {
|
||||
t.Error("no tensor layer found in manifest")
|
||||
}
|
||||
if !hasConfig {
|
||||
t.Error("no config layer found in manifest")
|
||||
}
|
||||
|
||||
// Verify status messages were sent
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only a safetensors file, no config.json
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Mock callbacks (minimal)
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing config.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
return LayerInfo{}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
return []LayerInfo{{}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create model.safetensors.index.json (should be skipped)
|
||||
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var configNames []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
configNames = append(configNames, name)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify model.safetensors.index.json was not included
|
||||
for _, name := range configNames {
|
||||
if name == "model.safetensors.index.json" {
|
||||
t.Error("model.safetensors.index.json should have been skipped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
wantParts []string // Parts that should appear in the path
|
||||
}{
|
||||
{
|
||||
name: "simple model name",
|
||||
modelName: "llama2",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with tag",
|
||||
modelName: "llama2:7b",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace",
|
||||
modelName: "myuser/mymodel",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace and tag",
|
||||
modelName: "myuser/mymodel:v1",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
|
||||
},
|
||||
{
|
||||
name: "fully qualified model name",
|
||||
modelName: "registry.example.com/namespace/model:tag",
|
||||
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveManifestPath(tt.modelName)
|
||||
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(got, part) {
|
||||
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerInfo(t *testing.T) {
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:abc123",
|
||||
Size: 1024,
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: "model.weight",
|
||||
}
|
||||
|
||||
if layer.Digest != "sha256:abc123" {
|
||||
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
|
||||
}
|
||||
if layer.Size != 1024 {
|
||||
t.Errorf("Size = %d, want %d", layer.Size, 1024)
|
||||
}
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
|
||||
}
|
||||
if layer.Name != "model.weight" {
|
||||
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig(t *testing.T) {
|
||||
config := ModelConfig{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion", "chat"},
|
||||
}
|
||||
|
||||
if config.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
|
||||
}
|
||||
if len(config.Capabilities) != 2 {
|
||||
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifest(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Config: ManifestLayer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
},
|
||||
Layers: []ManifestLayer{
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: "sha256:layer1",
|
||||
Size: 1000,
|
||||
Name: "weight.bin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if manifest.SchemaVersion != 2 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
|
||||
}
|
||||
if manifest.Config.Digest != "sha256:config" {
|
||||
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
|
||||
}
|
||||
if len(manifest.Layers) != 1 {
|
||||
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
|
||||
}
|
||||
if manifest.Layers[0].Name != "weight.bin" {
|
||||
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
component string
|
||||
want bool
|
||||
}{
|
||||
// VAE component should never be quantized
|
||||
{"vae weight", "decoder.weight", "vae", false},
|
||||
{"vae bias", "decoder.bias", "vae", false},
|
||||
|
||||
// Embeddings should not be quantized
|
||||
{"embedding weight", "embed_tokens.weight", "", false},
|
||||
{"embedding in name", "token_embedding.weight", "", false},
|
||||
|
||||
// Norms should not be quantized
|
||||
{"layer norm", "layer_norm.weight", "", false},
|
||||
{"rms norm", "rms_norm.weight", "", false},
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
|
||||
// Linear weights should be quantized
|
||||
{"linear weight", "q_proj.weight", "", true},
|
||||
{"attention weight", "self_attn.weight", "", true},
|
||||
{"mlp weight", "mlp.gate_proj.weight", "", true},
|
||||
|
||||
// Transformer component weights should be quantized
|
||||
{"transformer weight", "layers.0.weight", "transformer", true},
|
||||
{"text_encoder weight", "encoder.weight", "text_encoder", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantize(tt.tensor, tt.component)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
// Run with quantize enabled
|
||||
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify quantize was passed to callback (will be false for small test tensor)
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalImageGenModel creates a minimal diffusers-style model directory
|
||||
func createMinimalImageGenModel(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
|
||||
// Create model_index.json
|
||||
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0o644); err != nil {
|
||||
t.Fatalf("failed to write model_index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create transformer directory with a safetensors file
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
// Create transformer config
|
||||
transformerConfig := `{"hidden_size": 3072}`
|
||||
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0o644); err != nil {
|
||||
t.Fatalf("failed to write transformer config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var statusMessages []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-imagegen" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
|
||||
}
|
||||
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only transformer without model_index.json
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing model_index.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var quantizeRequested []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
231
x/imagegen/api/handler.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// RunnerScheduler is the interface for scheduling a model runner.
|
||||
// This is implemented by server.Server to avoid circular imports.
|
||||
type RunnerScheduler interface {
|
||||
ScheduleImageGenRunner(ctx *gin.Context, modelName string, opts api.Options, keepAlive *api.Duration) (llm.LlamaServer, error)
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the image generation API routes.
|
||||
func RegisterRoutes(r gin.IRouter, scheduler RunnerScheduler) {
|
||||
r.POST("/v1/images/generations", func(c *gin.Context) {
|
||||
ImageGenerationHandler(c, scheduler)
|
||||
})
|
||||
}
|
||||
|
||||
// ImageGenerationHandler handles OpenAI-compatible image generation requests.
|
||||
func ImageGenerationHandler(c *gin.Context, scheduler RunnerScheduler) {
|
||||
var req ImageGenerationRequest
|
||||
if err := c.BindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.Model == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "model is required"}})
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "prompt is required"}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if req.N == 0 {
|
||||
req.N = 1
|
||||
}
|
||||
if req.Size == "" {
|
||||
req.Size = "1024x1024"
|
||||
}
|
||||
if req.ResponseFormat == "" {
|
||||
req.ResponseFormat = "b64_json"
|
||||
}
|
||||
|
||||
// Verify model exists
|
||||
if imagegen.ResolveModelName(req.Model) == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": gin.H{"message": fmt.Sprintf("model %q not found", req.Model)}})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse size
|
||||
width, height := parseSize(req.Size)
|
||||
|
||||
// Build options - we repurpose NumCtx/NumGPU for width/height
|
||||
opts := api.Options{}
|
||||
opts.NumCtx = int(width)
|
||||
opts.NumGPU = int(height)
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, req.Model, opts, nil)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
status = http.StatusNotFound
|
||||
}
|
||||
c.JSON(status, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
handleStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
} else {
|
||||
handleNonStreamingResponse(c, runner, completionReq, req.ResponseFormat)
|
||||
}
|
||||
}
|
||||
|
||||
func handleStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
} else {
|
||||
progress := parseProgress(resp.Content)
|
||||
if progress.Total > 0 {
|
||||
c.SSEvent("progress", progress)
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.SSEvent("done", buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func handleNonStreamingResponse(c *gin.Context, runner llm.LlamaServer, req llm.CompletionRequest, format string) {
|
||||
var imageBase64 string
|
||||
err := runner.Completion(c.Request.Context(), req, func(resp llm.CompletionResponse) {
|
||||
if resp.Done {
|
||||
imageBase64 = extractBase64(resp.Content)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildResponse(imageBase64, format))
|
||||
}
|
||||
|
||||
func parseSize(size string) (int32, int32) {
|
||||
parts := strings.Split(size, "x")
|
||||
if len(parts) != 2 {
|
||||
return 1024, 1024
|
||||
}
|
||||
w, _ := strconv.Atoi(parts[0])
|
||||
h, _ := strconv.Atoi(parts[1])
|
||||
if w == 0 {
|
||||
w = 1024
|
||||
}
|
||||
if h == 0 {
|
||||
h = 1024
|
||||
}
|
||||
return int32(w), int32(h)
|
||||
}
|
||||
|
||||
func extractBase64(content string) string {
|
||||
if strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
return content[13:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseProgress(content string) ImageProgressEvent {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
return ImageProgressEvent{Step: step, Total: total}
|
||||
}
|
||||
|
||||
func buildResponse(imageBase64, format string) ImageGenerationResponse {
|
||||
resp := ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: make([]ImageData, 1),
|
||||
}
|
||||
|
||||
if imageBase64 == "" {
|
||||
return resp
|
||||
}
|
||||
|
||||
if format == "url" {
|
||||
// URL format not supported when using base64 transfer
|
||||
resp.Data[0].B64JSON = imageBase64
|
||||
} else {
|
||||
resp.Data[0].B64JSON = imageBase64
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// HandleGenerateRequest handles Ollama /api/generate requests for image gen models.
|
||||
// This allows routes.go to delegate image generation with minimal code.
|
||||
func HandleGenerateRequest(c *gin.Context, scheduler RunnerScheduler, modelName, prompt string, keepAlive *api.Duration, streamFn func(c *gin.Context, ch chan any)) {
|
||||
opts := api.Options{}
|
||||
|
||||
// Schedule runner
|
||||
runner, err := scheduler.ScheduleImageGenRunner(c, modelName, opts, keepAlive)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Build completion request
|
||||
completionReq := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Options: &opts,
|
||||
}
|
||||
|
||||
// Stream responses via channel
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := runner.Completion(c.Request.Context(), completionReq, func(resp llm.CompletionResponse) {
|
||||
ch <- GenerateResponse{
|
||||
Model: modelName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: resp.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't block - channel is already being consumed
|
||||
_ = err
|
||||
}
|
||||
}()
|
||||
|
||||
streamFn(c, ch)
|
||||
}
|
||||
|
||||
// GenerateResponse matches api.GenerateResponse structure for streaming.
|
||||
type GenerateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
}
|
||||
31
x/imagegen/api/types.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Package api provides OpenAI-compatible image generation API types.
|
||||
package api
|
||||
|
||||
// ImageGenerationRequest is an OpenAI-compatible image generation request.
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// ImageGenerationResponse is an OpenAI-compatible image generation response.
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
// ImageData contains the generated image data.
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// ImageProgressEvent is sent during streaming to indicate generation progress.
|
||||
type ImageProgressEvent struct {
|
||||
Step int `json:"step"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
@@ -7,6 +7,7 @@ package imagegen
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -38,20 +39,79 @@ func DefaultOptions() ImageGenOptions {
|
||||
return ImageGenOptions{
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
Steps: 0, // 0 means model default
|
||||
Steps: 9,
|
||||
Seed: 0, // 0 means random
|
||||
}
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// RegisterFlags adds image generation flags to the given command.
|
||||
// Flags are hidden since they only apply to image generation models.
|
||||
func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().Int("width", 1024, "Image width")
|
||||
cmd.Flags().Int("height", 1024, "Image height")
|
||||
cmd.Flags().Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
cmd.Flags().Int("steps", 9, "Denoising steps")
|
||||
cmd.Flags().Int("seed", 0, "Random seed (0 for random)")
|
||||
cmd.Flags().String("negative", "", "Negative prompt")
|
||||
// Hide from main flags section - shown in separate section via AppendFlagsDocs
|
||||
cmd.Flags().MarkHidden("width")
|
||||
cmd.Flags().MarkHidden("height")
|
||||
cmd.Flags().MarkHidden("steps")
|
||||
@@ -59,19 +119,6 @@ func RegisterFlags(cmd *cobra.Command) {
|
||||
cmd.Flags().MarkHidden("negative")
|
||||
}
|
||||
|
||||
// AppendFlagsDocs appends image generation flags documentation to the command's usage template.
|
||||
func AppendFlagsDocs(cmd *cobra.Command) {
|
||||
usage := `
|
||||
Image Generation Flags (experimental):
|
||||
--width int Image width
|
||||
--height int Image height
|
||||
--steps int Denoising steps
|
||||
--seed int Random seed
|
||||
--negative str Negative prompt
|
||||
`
|
||||
cmd.SetUsageTemplate(cmd.UsageTemplate() + usage)
|
||||
}
|
||||
|
||||
// RunCLI handles the CLI for image generation models.
|
||||
// Returns true if it handled the request, false if the caller should continue with normal flow.
|
||||
// Supports flags: --width, --height, --steps, --seed, --negative
|
||||
@@ -111,15 +158,17 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
return err
|
||||
}
|
||||
|
||||
// Build request with image gen options encoded in Options fields
|
||||
// NumCtx=width, NumGPU=height, NumPredict=steps, Seed=seed
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: prompt,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if opts.Seed != 0 {
|
||||
req.Options = map[string]any{"seed": opts.Seed}
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -133,25 +182,32 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
|
||||
var stepBar *progress.StepBar
|
||||
var imageBase64 string
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 {
|
||||
if stepBar == nil {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
stepBar.Set(int(resp.Completed))
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with image data
|
||||
if resp.Done && resp.Image != "" {
|
||||
imageBase64 = resp.Image
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.StopAndClear()
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -189,23 +245,6 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
return err
|
||||
}
|
||||
|
||||
// Preload the model with the specified keepalive
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
preloadReq := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
KeepAlive: keepAlive,
|
||||
}
|
||||
if err := client.Generate(cmd.Context(), preloadReq, func(resp api.GenerateResponse) error {
|
||||
return nil
|
||||
}); err != nil {
|
||||
p.StopAndClear()
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
p.StopAndClear()
|
||||
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
Placeholder: "Describe an image to generate (/help for commands)",
|
||||
@@ -243,7 +282,7 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
case strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/?"), strings.HasPrefix(line, "/help"):
|
||||
printInteractiveHelp()
|
||||
printInteractiveHelp(opts)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set "):
|
||||
if err := handleSetCommand(line[5:], &opts); err != nil {
|
||||
@@ -262,12 +301,12 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
req := &api.GenerateRequest{
|
||||
Model: modelName,
|
||||
Prompt: line,
|
||||
Width: int32(opts.Width),
|
||||
Height: int32(opts.Height),
|
||||
Steps: int32(opts.Steps),
|
||||
}
|
||||
if opts.Seed != 0 {
|
||||
req.Options = map[string]any{"seed": opts.Seed}
|
||||
Options: map[string]any{
|
||||
"num_ctx": opts.Width,
|
||||
"num_gpu": opts.Height,
|
||||
"num_predict": opts.Steps,
|
||||
"seed": opts.Seed,
|
||||
},
|
||||
}
|
||||
if keepAlive != nil {
|
||||
req.KeepAlive = keepAlive
|
||||
@@ -282,25 +321,32 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
|
||||
var imageBase64 string
|
||||
|
||||
err = client.Generate(cmd.Context(), req, func(resp api.GenerateResponse) error {
|
||||
// Handle progress updates using structured fields
|
||||
if resp.Total > 0 {
|
||||
if stepBar == nil {
|
||||
content := resp.Response
|
||||
|
||||
// Handle progress updates - parse step info and switch to step bar
|
||||
if strings.HasPrefix(content, "\rGenerating:") {
|
||||
var step, total int
|
||||
fmt.Sscanf(content, "\rGenerating: step %d/%d", &step, &total)
|
||||
if stepBar == nil && total > 0 {
|
||||
spinner.Stop()
|
||||
stepBar = progress.NewStepBar("Generating", int(resp.Total))
|
||||
stepBar = progress.NewStepBar("Generating", total)
|
||||
p.Add("", stepBar)
|
||||
}
|
||||
stepBar.Set(int(resp.Completed))
|
||||
if stepBar != nil {
|
||||
stepBar.Set(step)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle final response with image data
|
||||
if resp.Done && resp.Image != "" {
|
||||
imageBase64 = resp.Image
|
||||
// Handle final response with base64 image data
|
||||
if resp.Done && strings.HasPrefix(content, "IMAGE_BASE64:") {
|
||||
imageBase64 = content[13:]
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
p.StopAndClear()
|
||||
p.Stop()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
continue
|
||||
@@ -351,13 +397,12 @@ func sanitizeFilename(s string) string {
|
||||
}
|
||||
|
||||
// printInteractiveHelp prints help for interactive mode commands.
|
||||
// TODO: reconcile /set commands with /set parameter in text gen REPL (cmd/cmd.go)
|
||||
func printInteractiveHelp() {
|
||||
func printInteractiveHelp(opts ImageGenOptions) {
|
||||
fmt.Fprintln(os.Stderr, "Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /set width <n> Set image width")
|
||||
fmt.Fprintln(os.Stderr, " /set height <n> Set image height")
|
||||
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps")
|
||||
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed")
|
||||
fmt.Fprintln(os.Stderr, " /set width <n> Set image width (current:", opts.Width, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set height <n> Set image height (current:", opts.Height, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set steps <n> Set denoising steps (current:", opts.Steps, ")")
|
||||
fmt.Fprintln(os.Stderr, " /set seed <n> Set random seed (current:", opts.Seed, ", 0=random)")
|
||||
fmt.Fprintln(os.Stderr, " /set negative <s> Set negative prompt")
|
||||
fmt.Fprintln(os.Stderr, " /show Show current settings")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
|
||||
190
x/imagegen/client/create.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
// When quantize is true, returns multiple layers (weight + scales)
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
|
||||
if doQuantize {
|
||||
// Check if quantization is supported
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor (affine mode returns weight, scales, qbiases)
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales (use _scale suffix convention)
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name, // Keep original name for weight
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale", // Add _scale suffix
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, imagegen.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias", // Add _qbias suffix
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// Non-quantized path: just create a single layer
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
@@ -11,11 +11,10 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it,
|
||||
// quantizeTensor loads a tensor from safetensors format, quantizes it to affine int8,
|
||||
// and returns safetensors data for the quantized weights, scales, and biases.
|
||||
// Supported quantization types: "fp8" (affine 8-bit)
|
||||
// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
tmpDir := ensureTempDir()
|
||||
|
||||
// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
|
||||
@@ -51,18 +50,9 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
|
||||
mlx.Eval(arr)
|
||||
}
|
||||
|
||||
// Quantize based on quantization type
|
||||
var qweight, scales, qbiases *mlx.Array
|
||||
switch quantize {
|
||||
case "fp4":
|
||||
// affine mode: group_size=32, bits=4
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
|
||||
case "fp8":
|
||||
// affine mode: group_size=32, bits=8
|
||||
qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
|
||||
}
|
||||
// Quantize with affine mode: group_size=32, bits=8
|
||||
// Note: mxfp8 mode doesn't have matmul kernels in MLX, affine mode does
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, 32, 8, "affine")
|
||||
|
||||
// Eval and make contiguous for data access
|
||||
qweight = mlx.Contiguous(qweight)
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// quantizeTensor is not available without MLX
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
func quantizeTensor(r io.Reader, name, dtype string, shape []int32) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
|
||||
return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
|
||||
}
|
||||
|
||||
@@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string {
|
||||
return result
|
||||
}
|
||||
|
||||
func init() {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
|
||||
@@ -7,18 +7,15 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
@@ -51,9 +48,9 @@ func main() {
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
|
||||
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
width := flag.Int("width", 1024, "Image width")
|
||||
height := flag.Int("height", 1024, "Image height")
|
||||
steps := flag.Int("steps", 9, "Denoising steps")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
@@ -66,7 +63,7 @@ func main() {
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
@@ -84,11 +81,6 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MLX initialized successfully
|
||||
if !mlx.IsMLXAvailable() {
|
||||
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
||||
}
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
@@ -128,40 +120,29 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *flux2Flag:
|
||||
m := &flux2.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
case *glmImageFlag:
|
||||
m := &glm_image.Model{}
|
||||
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
|
||||
var loadErr error
|
||||
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
|
||||
loadErr = m.LoadFromPath(*modelPath)
|
||||
} else {
|
||||
loadErr = m.Load(*modelPath)
|
||||
}
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// Load input images with EXIF orientation correction
|
||||
var loadedImages []image.Image
|
||||
for _, path := range inputImages {
|
||||
img, loadErr := loadImageWithEXIF(path)
|
||||
if loadErr != nil {
|
||||
log.Fatalf("Failed to load image %s: %v", path, loadErr)
|
||||
}
|
||||
loadedImages = append(loadedImages, img)
|
||||
}
|
||||
// When input images provided and user didn't override dimensions, use 0 to match input
|
||||
fluxWidth := int32(*width)
|
||||
fluxHeight := int32(*height)
|
||||
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
|
||||
// Both unset, will auto-detect from input
|
||||
} else if len(loadedImages) > 0 && *width == 0 {
|
||||
fluxWidth = 0 // Compute from height + aspect ratio
|
||||
} else if len(loadedImages) > 0 && *height == 0 {
|
||||
fluxHeight = 0 // Compute from width + aspect ratio
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: fluxWidth,
|
||||
Height: fluxHeight,
|
||||
Steps: *steps,
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
InputImages: loadedImages,
|
||||
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
MaxVisualTokens: int32(*maxTokens),
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
@@ -320,8 +301,6 @@ func detectModelKind(modelPath string) (string, error) {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
case "Flux2KleinPipeline":
|
||||
return "flux2", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
@@ -342,12 +321,3 @@ func detectModelKind(modelPath string) (string, error) {
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
|
||||
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
|
||||
func loadImageWithEXIF(path string) (image.Image, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
return imagegen.DecodeImage(data)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package create
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -12,27 +12,43 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: fp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "fp4", "fp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
|
||||
}
|
||||
// IsTensorModelDir checks if the directory contains a tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// CreateModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
var torchDtype string // Read from component config for quantization display
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
|
||||
|
||||
for _, component := range components {
|
||||
componentDir := filepath.Join(modelDir, component)
|
||||
@@ -61,8 +77,8 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to " + quantize
|
||||
if quantize == "fp8" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to fp8"
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
@@ -87,14 +103,11 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
// Determine if this tensor should be quantized
|
||||
doQuantize := quantize == "fp8" && ShouldQuantize(tensorName, component)
|
||||
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, doQuantize)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
@@ -106,19 +119,6 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
}
|
||||
}
|
||||
|
||||
// Read torch_dtype from text_encoder config for quantization display
|
||||
if torchDtype == "" {
|
||||
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
||||
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
||||
torchDtype = cfg.TorchDtype
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
@@ -126,10 +126,13 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
"text_encoder/generation_config.json",
|
||||
"transformer/config.json",
|
||||
"vae/config.json",
|
||||
"vision_language_encoder/config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"tokenizer/tokenizer.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"processor/tokenizer.json", // GLM-Image main tokenizer
|
||||
"processor/tokenizer_config.json", // GLM-Image tokenizer config
|
||||
}
|
||||
|
||||
for _, cfgPath := range configFiles {
|
||||
@@ -164,11 +167,11 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
// Add parameter count (counted from tensor shapes during import)
|
||||
cfg["parameter_count"] = totalParams
|
||||
|
||||
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
||||
if quantize != "" {
|
||||
cfg["quantization"] = strings.ToUpper(quantize)
|
||||
// Add quantization info
|
||||
if quantize == "fp8" {
|
||||
cfg["quantization"] = "FP8"
|
||||
} else {
|
||||
cfg["quantization"] = torchDtype
|
||||
cfg["quantization"] = "BF16"
|
||||
}
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
@@ -211,12 +214,3 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size (32).
|
||||
func canQuantizeShape(shape []int32) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
return shape[len(shape)-1]%32 == 0
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -109,160 +108,3 @@ func clampF(v, min, max float32) float32 {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// DecodeImage decodes image bytes with EXIF orientation applied.
|
||||
func DecodeImage(data []byte) (image.Image, error) {
|
||||
orientation := readJPEGOrientation(data)
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return applyOrientation(img, orientation), nil
|
||||
}
|
||||
|
||||
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
|
||||
// Returns 1 (normal) for non-JPEG or if orientation not found.
|
||||
func readJPEGOrientation(data []byte) int {
|
||||
if len(data) < 2 || data[0] != 0xFF || data[1] != 0xD8 {
|
||||
return 1 // Not JPEG
|
||||
}
|
||||
|
||||
r := bytes.NewReader(data[2:])
|
||||
for {
|
||||
var marker [2]byte
|
||||
if _, err := r.Read(marker[:]); err != nil || marker[0] != 0xFF {
|
||||
return 1
|
||||
}
|
||||
|
||||
if marker[1] == 0xE1 { // APP1 (EXIF)
|
||||
var lenBytes [2]byte
|
||||
if _, err := r.Read(lenBytes[:]); err != nil {
|
||||
return 1
|
||||
}
|
||||
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
||||
if segLen < 14 {
|
||||
r.Seek(int64(segLen), 1)
|
||||
continue
|
||||
}
|
||||
seg := make([]byte, segLen)
|
||||
if _, err := r.Read(seg); err != nil {
|
||||
return 1
|
||||
}
|
||||
if string(seg[:4]) == "Exif" && seg[4] == 0 && seg[5] == 0 {
|
||||
return parseTIFFOrientation(seg[6:])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if marker[1] == 0xD9 || marker[1] == 0xDA {
|
||||
return 1 // EOI or SOS
|
||||
}
|
||||
if marker[1] >= 0xD0 && marker[1] <= 0xD7 {
|
||||
continue // RST markers
|
||||
}
|
||||
|
||||
var lenBytes [2]byte
|
||||
if _, err := r.Read(lenBytes[:]); err != nil {
|
||||
return 1
|
||||
}
|
||||
segLen := int(uint16(lenBytes[0])<<8|uint16(lenBytes[1])) - 2
|
||||
if segLen > 0 {
|
||||
r.Seek(int64(segLen), 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseTIFFOrientation(tiff []byte) int {
|
||||
if len(tiff) < 8 {
|
||||
return 1
|
||||
}
|
||||
|
||||
var big bool
|
||||
switch string(tiff[:2]) {
|
||||
case "MM":
|
||||
big = true
|
||||
case "II":
|
||||
big = false
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
|
||||
u16 := func(b []byte) uint16 {
|
||||
if big {
|
||||
return uint16(b[0])<<8 | uint16(b[1])
|
||||
}
|
||||
return uint16(b[1])<<8 | uint16(b[0])
|
||||
}
|
||||
u32 := func(b []byte) uint32 {
|
||||
if big {
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
return uint32(b[3])<<24 | uint32(b[2])<<16 | uint32(b[1])<<8 | uint32(b[0])
|
||||
}
|
||||
|
||||
if u16(tiff[2:4]) != 42 {
|
||||
return 1
|
||||
}
|
||||
|
||||
ifdOffset := u32(tiff[4:8])
|
||||
if int(ifdOffset)+2 > len(tiff) {
|
||||
return 1
|
||||
}
|
||||
|
||||
numEntries := u16(tiff[ifdOffset : ifdOffset+2])
|
||||
for i := range int(numEntries) {
|
||||
offset := ifdOffset + 2 + uint32(i)*12
|
||||
if int(offset)+12 > len(tiff) {
|
||||
break
|
||||
}
|
||||
if u16(tiff[offset:offset+2]) == 0x0112 { // Orientation tag
|
||||
o := int(u16(tiff[offset+8 : offset+10]))
|
||||
if o >= 1 && o <= 8 {
|
||||
return o
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func applyOrientation(img image.Image, orientation int) image.Image {
|
||||
if orientation <= 1 || orientation > 8 {
|
||||
return img
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
outW, outH := w, h
|
||||
if orientation >= 5 {
|
||||
outW, outH = h, w
|
||||
}
|
||||
|
||||
out := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
var dx, dy int
|
||||
switch orientation {
|
||||
case 2:
|
||||
dx, dy = w-1-x, y
|
||||
case 3:
|
||||
dx, dy = w-1-x, h-1-y
|
||||
case 4:
|
||||
dx, dy = x, h-1-y
|
||||
case 5:
|
||||
dx, dy = y, x
|
||||
case 6:
|
||||
dx, dy = h-1-y, x
|
||||
case 7:
|
||||
dx, dy = h-1-y, w-1-x
|
||||
case 8:
|
||||
dx, dy = y, w-1-x
|
||||
}
|
||||
out.Set(dx, dy, img.At(x+bounds.Min.X, y+bounds.Min.Y))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
19
x/imagegen/imagegen.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Image generation models (experimental)
|
||||
|
||||
Experimental image generation models are available for **macOS** in Ollama:
|
||||
|
||||
## Available models
|
||||
|
||||
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
|
||||
|
||||
```
|
||||
ollama run x/z-image-turbo
|
||||
```
|
||||
|
||||
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models
|
||||
|
||||
More models coming soon:
|
||||
|
||||
1. Qwen-Image-2512
|
||||
2. Qwen-Image-Edit-2511
|
||||
3. GLM-Image
|
||||
@@ -175,63 +175,3 @@ func (m *ModelManifest) HasTensorLayers() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ModelInfo contains metadata about an image generation model.
|
||||
type ModelInfo struct {
|
||||
Architecture string
|
||||
ParameterCount int64
|
||||
Quantization string
|
||||
}
|
||||
|
||||
// GetModelInfo returns metadata about an image generation model.
|
||||
func GetModelInfo(modelName string) (*ModelInfo, error) {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
info := &ModelInfo{}
|
||||
|
||||
// Read model_index.json for architecture, parameter count, and quantization
|
||||
if data, err := manifest.ReadConfig("model_index.json"); err == nil {
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ParameterCount int64 `json:"parameter_count"`
|
||||
Quantization string `json:"quantization"`
|
||||
}
|
||||
if json.Unmarshal(data, &index) == nil {
|
||||
info.Architecture = index.Architecture
|
||||
info.ParameterCount = index.ParameterCount
|
||||
info.Quantization = index.Quantization
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: detect quantization from tensor names if not in config
|
||||
if info.Quantization == "" {
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if strings.HasSuffix(layer.Name, ".weight_scale") {
|
||||
info.Quantization = "FP8"
|
||||
break
|
||||
}
|
||||
}
|
||||
if info.Quantization == "" {
|
||||
info.Quantization = "BF16"
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: estimate parameter count if not in config
|
||||
if info.ParameterCount == 0 {
|
||||
var totalSize int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
|
||||
totalSize += layer.Size
|
||||
}
|
||||
}
|
||||
}
|
||||
// Assume BF16 (2 bytes/param) as rough estimate
|
||||
info.ParameterCount = totalSize / 2
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
@@ -24,8 +24,10 @@ var SupportedBackends = []string{"metal", "cuda", "cpu"}
|
||||
|
||||
// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
|
||||
var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 20 * GB, // ~20GB for Flux
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
|
||||
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
|
||||
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
@@ -71,38 +73,31 @@ func ResolveModelName(modelName string) string {
|
||||
// EstimateVRAM returns the estimated VRAM needed for an image generation model.
|
||||
// Returns a conservative default of 21GB if the model type cannot be determined.
|
||||
func EstimateVRAM(modelName string) uint64 {
|
||||
className := DetectModelType(modelName)
|
||||
if estimate, ok := modelVRAMEstimates[className]; ok {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// Parse just the class name
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
if estimate, ok := modelVRAMEstimates[index.ClassName]; ok {
|
||||
return estimate
|
||||
}
|
||||
return 21 * GB
|
||||
}
|
||||
|
||||
// DetectModelType reads model_index.json and returns the model type.
|
||||
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
|
||||
// Returns empty string if detection fails.
|
||||
func DetectModelType(modelName string) string {
|
||||
manifest, err := LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var index struct {
|
||||
Architecture string `json:"architecture"`
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Prefer architecture (Ollama format), fall back to _class_name (diffusers)
|
||||
if index.Architecture != "" {
|
||||
return index.Architecture
|
||||
}
|
||||
return index.ClassName
|
||||
// HasTensorLayers checks if the given model has tensor layers.
|
||||
func HasTensorLayers(modelName string) bool {
|
||||
return ResolveModelName(modelName) != ""
|
||||
}
|
||||
|
||||
@@ -72,8 +72,9 @@ func TestCheckMemoryRequirements(t *testing.T) {
|
||||
func TestModelVRAMEstimates(t *testing.T) {
|
||||
// Verify the VRAM estimates map has expected entries
|
||||
expected := map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 20 * GB,
|
||||
"ZImagePipeline": 21 * GB,
|
||||
"FluxPipeline": 21 * GB,
|
||||
"QwenImagePipeline": 80 * GB,
|
||||
}
|
||||
|
||||
for name, expectedVRAM := range expected {
|
||||
@@ -93,6 +94,13 @@ func TestEstimateVRAMDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTensorLayers(t *testing.T) {
|
||||
// Non-existent model should return false
|
||||
if HasTensorLayers("nonexistent-model") {
|
||||
t.Error("HasTensorLayers() should return false for non-existent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelName(t *testing.T) {
|
||||
// Non-existent model should return empty string
|
||||
result := ResolveModelName("nonexistent-model")
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx.h"
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// Forward declaration for Go callback
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package mlx provides Go bindings for the MLX-C library with dynamic loading support.
|
||||
//
|
||||
//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c
|
||||
package mlx
|
||||
@@ -1,439 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
// This tool generates MLX-C dynamic loading wrappers.
|
||||
// Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
func findHeaders(directory string) ([]string, error) {
|
||||
var headers []string
|
||||
err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return headers, err
|
||||
}
|
||||
|
||||
func cleanContent(content string) string {
|
||||
// Remove single-line comments
|
||||
re := regexp.MustCompile(`//.*?\n`)
|
||||
content = re.ReplaceAllString(content, "\n")
|
||||
|
||||
// Remove multi-line comments
|
||||
re = regexp.MustCompile(`/\*.*?\*/`)
|
||||
content = re.ReplaceAllString(content, "")
|
||||
|
||||
// Remove preprocessor directives (lines starting with #) - use multiline mode
|
||||
re = regexp.MustCompile(`(?m)^\s*#.*?$`)
|
||||
content = re.ReplaceAllString(content, "")
|
||||
|
||||
// Remove extern "C" { and } blocks more conservatively
|
||||
// Only remove the extern "C" { line, not the content inside
|
||||
re = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`)
|
||||
content = re.ReplaceAllString(content, "\n")
|
||||
// Remove standalone closing braces that are not part of function declarations
|
||||
re = regexp.MustCompile(`\n\s*\}\s*\n`)
|
||||
content = re.ReplaceAllString(content, "\n")
|
||||
|
||||
// Collapse whitespace and newlines
|
||||
re = regexp.MustCompile(`\s+`)
|
||||
content = re.ReplaceAllString(content, " ")
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
func extractParamNames(params string) []string {
|
||||
if params == "" || strings.TrimSpace(params) == "void" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var names []string
|
||||
|
||||
// Split by comma, but respect parentheses (for function pointers)
|
||||
parts := splitParams(params)
|
||||
|
||||
// Remove array brackets
|
||||
arrayBrackets := regexp.MustCompile(`\[.*?\]`)
|
||||
|
||||
// Function pointer pattern
|
||||
funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`)
|
||||
|
||||
// Type keywords to skip
|
||||
typeKeywords := map[string]bool{
|
||||
"const": true,
|
||||
"struct": true,
|
||||
"unsigned": true,
|
||||
"signed": true,
|
||||
"long": true,
|
||||
"short": true,
|
||||
"int": true,
|
||||
"char": true,
|
||||
"float": true,
|
||||
"double": true,
|
||||
"void": true,
|
||||
"size_t": true,
|
||||
"uint8_t": true,
|
||||
"uint16_t": true,
|
||||
"uint32_t": true,
|
||||
"uint64_t": true,
|
||||
"int8_t": true,
|
||||
"int16_t": true,
|
||||
"int32_t": true,
|
||||
"int64_t": true,
|
||||
"intptr_t": true,
|
||||
"uintptr_t": true,
|
||||
}
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove array brackets
|
||||
part = arrayBrackets.ReplaceAllString(part, "")
|
||||
|
||||
// For function pointers like "void (*callback)(int)"
|
||||
if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 {
|
||||
names = append(names, matches[1])
|
||||
continue
|
||||
}
|
||||
|
||||
// Regular parameter: last identifier
|
||||
tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1)
|
||||
if len(tokens) > 0 {
|
||||
// The last token is usually the parameter name
|
||||
// Skip type keywords
|
||||
for i := len(tokens) - 1; i >= 0; i-- {
|
||||
if !typeKeywords[tokens[i]] {
|
||||
names = append(names, tokens[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
func splitParams(params string) []string {
|
||||
var parts []string
|
||||
var current bytes.Buffer
|
||||
depth := 0
|
||||
|
||||
for _, char := range params + "," {
|
||||
switch char {
|
||||
case '(':
|
||||
depth++
|
||||
current.WriteRune(char)
|
||||
case ')':
|
||||
depth--
|
||||
current.WriteRune(char)
|
||||
case ',':
|
||||
if depth == 0 {
|
||||
parts = append(parts, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(char)
|
||||
}
|
||||
default:
|
||||
current.WriteRune(char)
|
||||
}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
func parseFunctions(content string) []Function {
|
||||
var functions []Function
|
||||
|
||||
// Match function declarations: return_type function_name(params);
|
||||
// Matches both mlx_* and _mlx_* functions
|
||||
pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`)
|
||||
|
||||
matches := pattern.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range matches {
|
||||
returnType := strings.TrimSpace(match[1])
|
||||
funcName := strings.TrimSpace(match[2])
|
||||
params := strings.TrimSpace(match[3])
|
||||
|
||||
// Skip if this looks like a variable declaration
|
||||
if params == "" || strings.Contains(params, "{") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean up return type
|
||||
returnType = strings.Join(strings.Fields(returnType), " ")
|
||||
|
||||
// Extract parameter names
|
||||
paramNames := extractParamNames(params)
|
||||
|
||||
// Check if ARM64 guard is needed
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
return functions
|
||||
}
|
||||
|
||||
func needsARM64Guard(name, retType, params string) bool {
|
||||
return strings.Contains(name, "float16") ||
|
||||
strings.Contains(name, "bfloat16") ||
|
||||
strings.Contains(retType, "float16_t") ||
|
||||
strings.Contains(retType, "bfloat16_t") ||
|
||||
strings.Contains(params, "float16_t") ||
|
||||
strings.Contains(params, "bfloat16_t")
|
||||
}
|
||||
|
||||
func generateWrapperFiles(functions []Function, headerPath, implPath string) error {
|
||||
// Generate header file
|
||||
var headerBuf bytes.Buffer
|
||||
|
||||
headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n")
|
||||
headerBuf.WriteString("//\n")
|
||||
headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n")
|
||||
headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n")
|
||||
headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n")
|
||||
headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n")
|
||||
headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n")
|
||||
|
||||
headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n")
|
||||
headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
headerBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
headerBuf.WriteString("#include <stdio.h>\n\n")
|
||||
|
||||
// Undef all MLX functions to avoid conflicts
|
||||
headerBuf.WriteString("// Undefine any existing MLX function macros\n")
|
||||
for _, fn := range functions {
|
||||
headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name))
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Function pointer extern declarations
|
||||
headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
|
||||
// Initialization function declaration
|
||||
headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n")
|
||||
headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n")
|
||||
|
||||
// Wrapper function declarations
|
||||
headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n")
|
||||
headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
headerBuf.WriteString("#endif\n")
|
||||
}
|
||||
headerBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n")
|
||||
|
||||
// Write header file
|
||||
if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write header file: %w", err)
|
||||
}
|
||||
|
||||
// Generate implementation file
|
||||
var implBuf bytes.Buffer
|
||||
|
||||
implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n")
|
||||
implBuf.WriteString("// This file contains the function pointer definitions and initialization\n")
|
||||
implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n")
|
||||
|
||||
implBuf.WriteString("#include \"mlx/c/mlx.h\"\n")
|
||||
implBuf.WriteString("#include \"mlx_dynamic.h\"\n")
|
||||
implBuf.WriteString("#include <stdio.h>\n")
|
||||
implBuf.WriteString("#include <dlfcn.h>\n\n")
|
||||
|
||||
// Function pointer definitions
|
||||
implBuf.WriteString("// Function pointer definitions\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
|
||||
// Initialization function
|
||||
implBuf.WriteString("// Initialize all function pointers via dlsym\n")
|
||||
implBuf.WriteString("int mlx_load_functions(void* handle) {\n")
|
||||
implBuf.WriteString(" if (handle == NULL) {\n")
|
||||
implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n")
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n\n")
|
||||
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name))
|
||||
implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name))
|
||||
implBuf.WriteString(" return -1;\n")
|
||||
implBuf.WriteString(" }\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
}
|
||||
|
||||
implBuf.WriteString(" return 0;\n")
|
||||
implBuf.WriteString("}\n\n")
|
||||
|
||||
// Wrapper function implementations
|
||||
implBuf.WriteString("// Wrapper function implementations that call through function pointers\n")
|
||||
for _, fn := range functions {
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n")
|
||||
}
|
||||
implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params))
|
||||
|
||||
// Call through function pointer
|
||||
if fn.ReturnType != "void" {
|
||||
implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name))
|
||||
} else {
|
||||
implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name))
|
||||
}
|
||||
|
||||
// Pass parameters
|
||||
implBuf.WriteString(strings.Join(fn.ParamNames, ", "))
|
||||
implBuf.WriteString(");\n")
|
||||
implBuf.WriteString("}\n")
|
||||
if fn.NeedsARM64Guard {
|
||||
implBuf.WriteString("#endif\n")
|
||||
}
|
||||
implBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write implementation file
|
||||
if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write implementation file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go <mlx-c-include-dir> <output-header> [output-impl]\n")
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n")
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
headerDir := args[0]
|
||||
outputHeader := args[1]
|
||||
// Default implementation file is same name with .c extension
|
||||
outputImpl := outputHeader
|
||||
if len(args) > 2 {
|
||||
outputImpl = args[2]
|
||||
} else if strings.HasSuffix(outputHeader, ".h") {
|
||||
outputImpl = outputHeader[:len(outputHeader)-2] + ".c"
|
||||
}
|
||||
|
||||
// Check if header directory exists
|
||||
if _, err := os.Stat(headerDir); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir)
|
||||
fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n")
|
||||
fmt.Fprintf(os.Stderr, " cmake -B build\n\n")
|
||||
fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir)
|
||||
|
||||
// Find all headers
|
||||
headers, err := findHeaders(headerDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers))
|
||||
|
||||
// Parse all headers
|
||||
var allFunctions []Function
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, header := range headers {
|
||||
content, err := os.ReadFile(header)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err)
|
||||
continue
|
||||
}
|
||||
|
||||
cleaned := cleanContent(string(content))
|
||||
functions := parseFunctions(cleaned)
|
||||
|
||||
// Deduplicate
|
||||
for _, fn := range functions {
|
||||
if !seen[fn.Name] {
|
||||
seen[fn.Name] = true
|
||||
allFunctions = append(allFunctions, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions))
|
||||
|
||||
// Generate wrapper files
|
||||
if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl)
|
||||
}
|
||||
5786
x/imagegen/mlx/mlx.c
@@ -3,13 +3,12 @@
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR}
|
||||
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
|
||||
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
|
||||
#cgo linux LDFLAGS: -lstdc++ -ldl
|
||||
#cgo windows LDFLAGS: -lstdc++
|
||||
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
|
||||
|
||||
// Use generated wrappers instead of direct MLX headers
|
||||
#include "mlx.h"
|
||||
#include "mlx/c/mlx.h"
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
@@ -43,6 +42,192 @@ static inline mlx_stream cpu_stream() {
|
||||
// CGO noescape/nocallback hints to reduce CGO overhead
|
||||
// noescape: pointers won't escape, no heap allocation needed
|
||||
// nocallback: function won't call back into Go
|
||||
#cgo noescape mlx_add
|
||||
#cgo nocallback mlx_add
|
||||
#cgo noescape mlx_subtract
|
||||
#cgo nocallback mlx_subtract
|
||||
#cgo noescape mlx_multiply
|
||||
#cgo nocallback mlx_multiply
|
||||
#cgo noescape mlx_divide
|
||||
#cgo nocallback mlx_divide
|
||||
#cgo noescape mlx_negative
|
||||
#cgo nocallback mlx_negative
|
||||
#cgo noescape mlx_abs
|
||||
#cgo nocallback mlx_abs
|
||||
#cgo noescape mlx_exp
|
||||
#cgo nocallback mlx_exp
|
||||
#cgo noescape mlx_log
|
||||
#cgo nocallback mlx_log
|
||||
#cgo noescape mlx_sqrt
|
||||
#cgo nocallback mlx_sqrt
|
||||
#cgo noescape mlx_rsqrt
|
||||
#cgo nocallback mlx_rsqrt
|
||||
#cgo noescape mlx_square
|
||||
#cgo nocallback mlx_square
|
||||
#cgo noescape mlx_power
|
||||
#cgo nocallback mlx_power
|
||||
#cgo noescape mlx_erf
|
||||
#cgo nocallback mlx_erf
|
||||
#cgo noescape mlx_sigmoid
|
||||
#cgo nocallback mlx_sigmoid
|
||||
#cgo noescape mlx_tanh
|
||||
#cgo nocallback mlx_tanh
|
||||
#cgo noescape mlx_sin
|
||||
#cgo nocallback mlx_sin
|
||||
#cgo noescape mlx_cos
|
||||
#cgo nocallback mlx_cos
|
||||
#cgo noescape mlx_maximum
|
||||
#cgo nocallback mlx_maximum
|
||||
#cgo noescape mlx_minimum
|
||||
#cgo nocallback mlx_minimum
|
||||
#cgo noescape mlx_clip
|
||||
#cgo nocallback mlx_clip
|
||||
#cgo noescape mlx_sum
|
||||
#cgo nocallback mlx_sum
|
||||
#cgo noescape mlx_sum_axis
|
||||
#cgo nocallback mlx_sum_axis
|
||||
#cgo noescape mlx_mean
|
||||
#cgo nocallback mlx_mean
|
||||
#cgo noescape mlx_mean_axis
|
||||
#cgo nocallback mlx_mean_axis
|
||||
#cgo noescape mlx_var_axis
|
||||
#cgo nocallback mlx_var_axis
|
||||
#cgo noescape mlx_argmax
|
||||
#cgo nocallback mlx_argmax
|
||||
#cgo noescape mlx_argmax_axis
|
||||
#cgo nocallback mlx_argmax_axis
|
||||
#cgo noescape mlx_softmax_axis
|
||||
#cgo nocallback mlx_softmax_axis
|
||||
#cgo noescape mlx_cumsum
|
||||
#cgo nocallback mlx_cumsum
|
||||
#cgo noescape mlx_matmul
|
||||
#cgo nocallback mlx_matmul
|
||||
#cgo noescape mlx_addmm
|
||||
#cgo nocallback mlx_addmm
|
||||
#cgo noescape mlx_gather_mm
|
||||
#cgo nocallback mlx_gather_mm
|
||||
#cgo noescape mlx_gather_qmm
|
||||
#cgo nocallback mlx_gather_qmm
|
||||
#cgo noescape mlx_reshape
|
||||
#cgo nocallback mlx_reshape
|
||||
#cgo noescape mlx_transpose_axes
|
||||
#cgo nocallback mlx_transpose_axes
|
||||
#cgo noescape mlx_expand_dims
|
||||
#cgo nocallback mlx_expand_dims
|
||||
#cgo noescape mlx_squeeze_axis
|
||||
#cgo nocallback mlx_squeeze_axis
|
||||
#cgo noescape mlx_flatten
|
||||
#cgo nocallback mlx_flatten
|
||||
#cgo noescape mlx_concatenate_axis
|
||||
#cgo nocallback mlx_concatenate_axis
|
||||
#cgo noescape mlx_slice
|
||||
#cgo nocallback mlx_slice
|
||||
#cgo noescape mlx_slice_update
|
||||
#cgo nocallback mlx_slice_update
|
||||
#cgo noescape mlx_as_strided
|
||||
#cgo nocallback mlx_as_strided
|
||||
#cgo noescape mlx_view
|
||||
#cgo nocallback mlx_view
|
||||
#cgo noescape mlx_contiguous
|
||||
#cgo nocallback mlx_contiguous
|
||||
#cgo noescape mlx_pad
|
||||
#cgo nocallback mlx_pad
|
||||
#cgo noescape mlx_tile
|
||||
#cgo nocallback mlx_tile
|
||||
#cgo noescape mlx_take_axis
|
||||
#cgo nocallback mlx_take_axis
|
||||
#cgo noescape mlx_take_along_axis
|
||||
#cgo nocallback mlx_take_along_axis
|
||||
#cgo noescape mlx_put_along_axis
|
||||
#cgo nocallback mlx_put_along_axis
|
||||
#cgo noescape mlx_where
|
||||
#cgo nocallback mlx_where
|
||||
#cgo noescape mlx_argsort_axis
|
||||
#cgo nocallback mlx_argsort_axis
|
||||
#cgo noescape mlx_argpartition_axis
|
||||
#cgo nocallback mlx_argpartition_axis
|
||||
#cgo noescape mlx_topk_axis
|
||||
#cgo nocallback mlx_topk_axis
|
||||
#cgo noescape mlx_less
|
||||
#cgo nocallback mlx_less
|
||||
#cgo noescape mlx_greater_equal
|
||||
#cgo nocallback mlx_greater_equal
|
||||
#cgo noescape mlx_logical_and
|
||||
#cgo nocallback mlx_logical_and
|
||||
#cgo noescape mlx_zeros
|
||||
#cgo nocallback mlx_zeros
|
||||
#cgo noescape mlx_zeros_like
|
||||
#cgo nocallback mlx_zeros_like
|
||||
#cgo noescape mlx_ones
|
||||
#cgo nocallback mlx_ones
|
||||
#cgo noescape mlx_full
|
||||
#cgo nocallback mlx_full
|
||||
#cgo noescape mlx_arange
|
||||
#cgo nocallback mlx_arange
|
||||
#cgo noescape mlx_linspace
|
||||
#cgo nocallback mlx_linspace
|
||||
#cgo noescape mlx_tri
|
||||
#cgo nocallback mlx_tri
|
||||
#cgo noescape mlx_astype
|
||||
#cgo nocallback mlx_astype
|
||||
#cgo noescape mlx_fast_rms_norm
|
||||
#cgo nocallback mlx_fast_rms_norm
|
||||
#cgo noescape mlx_fast_rope
|
||||
#cgo nocallback mlx_fast_rope
|
||||
#cgo noescape mlx_fast_scaled_dot_product_attention
|
||||
#cgo nocallback mlx_fast_scaled_dot_product_attention
|
||||
#cgo noescape mlx_conv2d
|
||||
#cgo nocallback mlx_conv2d
|
||||
#cgo noescape mlx_conv3d
|
||||
#cgo nocallback mlx_conv3d
|
||||
#cgo noescape mlx_random_key
|
||||
#cgo nocallback mlx_random_key
|
||||
#cgo noescape mlx_random_split
|
||||
#cgo nocallback mlx_random_split
|
||||
#cgo noescape mlx_random_categorical_num_samples
|
||||
#cgo nocallback mlx_random_categorical_num_samples
|
||||
#cgo noescape mlx_random_normal
|
||||
#cgo nocallback mlx_random_normal
|
||||
#cgo noescape mlx_random_uniform
|
||||
#cgo nocallback mlx_random_uniform
|
||||
#cgo noescape mlx_array_eval
|
||||
#cgo nocallback mlx_array_eval
|
||||
#cgo noescape mlx_eval
|
||||
#cgo nocallback mlx_eval
|
||||
#cgo noescape mlx_async_eval
|
||||
#cgo nocallback mlx_async_eval
|
||||
#cgo noescape mlx_synchronize
|
||||
#cgo nocallback mlx_synchronize
|
||||
#cgo noescape mlx_array_new
|
||||
#cgo nocallback mlx_array_new
|
||||
#cgo noescape mlx_array_new_data
|
||||
#cgo nocallback mlx_array_new_data
|
||||
#cgo noescape mlx_array_new_float
|
||||
#cgo nocallback mlx_array_new_float
|
||||
#cgo noescape mlx_array_free
|
||||
#cgo nocallback mlx_array_free
|
||||
#cgo noescape mlx_array_size
|
||||
#cgo nocallback mlx_array_size
|
||||
#cgo noescape mlx_array_ndim
|
||||
#cgo nocallback mlx_array_ndim
|
||||
#cgo noescape mlx_array_dim
|
||||
#cgo nocallback mlx_array_dim
|
||||
#cgo noescape mlx_array_dtype
|
||||
#cgo nocallback mlx_array_dtype
|
||||
#cgo noescape mlx_array_item_int32
|
||||
#cgo nocallback mlx_array_item_int32
|
||||
#cgo noescape mlx_vector_array_new_data
|
||||
#cgo nocallback mlx_vector_array_new_data
|
||||
#cgo noescape mlx_vector_array_free
|
||||
#cgo nocallback mlx_vector_array_free
|
||||
#cgo noescape mlx_array_new_int
|
||||
#cgo nocallback mlx_array_new_int
|
||||
#cgo noescape mlx_stream_new_device
|
||||
#cgo nocallback mlx_stream_new_device
|
||||
#cgo noescape mlx_get_default_stream
|
||||
#cgo nocallback mlx_get_default_stream
|
||||
#cgo noescape mlx_set_default_stream
|
||||
#cgo nocallback mlx_set_default_stream
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
@@ -1137,27 +1322,6 @@ func RMSNormNoWeight(x *Array, eps float32) *Array {
|
||||
return RMSNorm(x, ones, eps)
|
||||
}
|
||||
|
||||
// LayerNorm applies layer normalization without learnable params
|
||||
// (x - mean) / sqrt(var + eps)
|
||||
func LayerNorm(x *Array, eps float32) *Array {
|
||||
return LayerNormWithWeightBias(x, nil, nil, eps)
|
||||
}
|
||||
|
||||
// LayerNormWithWeightBias computes layer normalization using mlx.fast
|
||||
// weight and bias can be nil for elementwise_affine=False
|
||||
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
|
||||
res := C.mlx_array_new()
|
||||
var wc, bc C.mlx_array
|
||||
if weight != nil {
|
||||
wc = weight.c
|
||||
}
|
||||
if bias != nil {
|
||||
bc = bias.c
|
||||
}
|
||||
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
// RoPE applies rotary position embeddings using mlx.fast
|
||||
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
res := C.mlx_array_new()
|
||||
@@ -1632,57 +1796,7 @@ func ArgmaxKeepArray(logits *Array) *Array {
|
||||
var RandomState = []*Array{nil}
|
||||
var randomStateMu sync.Mutex
|
||||
|
||||
var mlxInitialized bool
|
||||
var mlxInitError error
|
||||
|
||||
// InitMLX initializes the MLX library by dynamically loading libmlxc.
|
||||
// This must be called before using any MLX functions.
|
||||
// Returns an error if the library cannot be loaded.
|
||||
func InitMLX() error {
|
||||
if mlxInitialized {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Try to load the MLX dynamic library
|
||||
ret := C.mlx_dynamic_init()
|
||||
if ret != 0 {
|
||||
errMsg := C.GoString(C.mlx_dynamic_error())
|
||||
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
// Initialize all function pointers via dlsym
|
||||
handle := C.mlx_get_handle()
|
||||
ret = C.mlx_load_functions(handle)
|
||||
if ret != 0 {
|
||||
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
mlxInitialized = true
|
||||
mlxInitError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMLXAvailable returns whether MLX was successfully initialized
|
||||
func IsMLXAvailable() bool {
|
||||
return mlxInitialized && mlxInitError == nil
|
||||
}
|
||||
|
||||
// GetMLXInitError returns any error that occurred during MLX initialization
|
||||
func GetMLXInitError() error {
|
||||
return mlxInitError
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Initialize MLX dynamic library first
|
||||
if err := InitMLX(); err != nil {
|
||||
// Don't panic in init - let the caller handle the error
|
||||
// Store the error for later retrieval
|
||||
mlxInitError = err
|
||||
return
|
||||
}
|
||||
|
||||
// Lock main goroutine to OS thread for CUDA context stability.
|
||||
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
|
||||
runtime.LockOSThread()
|
||||
|
||||
2337
x/imagegen/mlx/mlx.h
@@ -1,144 +0,0 @@
|
||||
// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
|
||||
// This file provides runtime dynamic loading of libmlxc instead of link-time binding
|
||||
|
||||
#include "mlx_dynamic.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
typedef HMODULE lib_handle_t;
|
||||
#define LOAD_LIB(path) LoadLibraryA(path)
|
||||
#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
|
||||
#define CLOSE_LIB(handle) FreeLibrary(handle)
|
||||
#define LIB_ERROR() "LoadLibrary failed"
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
typedef void* lib_handle_t;
|
||||
#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define GET_SYMBOL(handle, name) dlsym(handle, name)
|
||||
#define CLOSE_LIB(handle) dlclose(handle)
|
||||
#define LIB_ERROR() dlerror()
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static lib_handle_t mlx_handle = NULL;
|
||||
static int mlx_initialized = 0;
|
||||
static char mlx_error_buffer[512] = {0};
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Get path to library in same directory as executable
|
||||
static char* get_exe_relative_path(const char* libname) {
|
||||
static char path[1024];
|
||||
uint32_t size = sizeof(path);
|
||||
if (_NSGetExecutablePath(path, &size) != 0) {
|
||||
return NULL;
|
||||
}
|
||||
// Get directory of executable
|
||||
char* dir = dirname(path);
|
||||
static char fullpath[1024];
|
||||
snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname);
|
||||
return fullpath;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Try to load library from a specific path
|
||||
static int try_load_lib(const char* path) {
|
||||
if (!path) return 0;
|
||||
mlx_handle = LOAD_LIB(path);
|
||||
return mlx_handle != NULL;
|
||||
}
|
||||
|
||||
// Initialize MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
// On failure, call mlx_dynamic_error() to get error message
|
||||
int mlx_dynamic_init(void) {
|
||||
if (mlx_initialized) {
|
||||
return 0; // Already initialized
|
||||
}
|
||||
|
||||
const char* lib_path = NULL;
|
||||
const char* tried_paths[8] = {0};
|
||||
int num_tried = 0;
|
||||
|
||||
#ifdef _WIN32
|
||||
// Windows: try same directory as executable
|
||||
lib_path = "libmlxc.dll";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#elif defined(__APPLE__)
|
||||
// macOS: try executable directory first
|
||||
lib_path = get_exe_relative_path("libmlxc.dylib");
|
||||
if (lib_path) {
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
}
|
||||
// Try build directory (for tests run from repo root)
|
||||
lib_path = "./build/lib/ollama/libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.dylib";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#else
|
||||
// Linux: try build directory first (for tests)
|
||||
lib_path = "./build/lib/ollama/libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
// Fallback to system paths
|
||||
lib_path = "libmlxc.so";
|
||||
tried_paths[num_tried++] = lib_path;
|
||||
if (try_load_lib(lib_path)) goto success;
|
||||
#endif
|
||||
|
||||
// Failed to load library - build error message with all tried paths
|
||||
{
|
||||
const char* err = LIB_ERROR();
|
||||
int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Failed to load libmlxc library. Tried: ");
|
||||
for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) {
|
||||
offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
"%s%s", i > 0 ? ", " : "", tried_paths[i]);
|
||||
}
|
||||
if (err) {
|
||||
snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset,
|
||||
". Last error: %s", err);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
|
||||
success:
|
||||
mlx_initialized = 1;
|
||||
snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
|
||||
"MLX: Successfully loaded %s", lib_path ? lib_path : "library");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Get the last error message
|
||||
const char* mlx_dynamic_error(void) {
|
||||
return mlx_error_buffer;
|
||||
}
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void) {
|
||||
return mlx_initialized;
|
||||
}
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void) {
|
||||
return mlx_handle;
|
||||
}
|
||||
|
||||
// Cleanup (optional, called at program exit)
|
||||
void mlx_dynamic_cleanup(void) {
|
||||
if (mlx_handle != NULL) {
|
||||
CLOSE_LIB(mlx_handle);
|
||||
mlx_handle = NULL;
|
||||
mlx_initialized = 0;
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
// mlx_dynamic.h - Dynamic loading interface for MLX-C library
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Initialize the MLX dynamic library
|
||||
// Returns 0 on success, -1 on failure
|
||||
int mlx_dynamic_init(void);
|
||||
|
||||
// Get the last error message from dynamic loading
|
||||
const char* mlx_dynamic_error(void);
|
||||
|
||||
// Check if MLX is initialized
|
||||
int mlx_dynamic_is_initialized(void);
|
||||
|
||||
// Get the library handle (for use by generated wrappers)
|
||||
void* mlx_get_handle(void);
|
||||
|
||||
// Cleanup resources (optional, for clean shutdown)
|
||||
void mlx_dynamic_cleanup(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -4,30 +4,9 @@ package mlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping MLX tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
|
||||
func TestBasicCleanup(t *testing.T) {
|
||||
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
|
||||
|
||||
@@ -1,539 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
|
||||
// Klein is a 4B parameter distilled model that supports sub-second inference.
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 4 for Klein)
|
||||
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
InputImages []image.Image // Reference images for image conditioning (already loaded)
|
||||
}
|
||||
|
||||
// Model represents a FLUX.2 Klein model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
TextEncoder *qwen3.TextEncoder
|
||||
Transformer *Flux2Transformer2DModel
|
||||
VAE *AutoencoderKLFlux2
|
||||
SchedulerConfig *SchedulerConfig
|
||||
}
|
||||
|
||||
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
|
||||
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
|
||||
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
|
||||
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
|
||||
var TextEncoderLayerIndices = []int{8, 17, 26}
|
||||
|
||||
// Load loads the FLUX.2 Klein model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer
|
||||
fmt.Print(" Loading tokenizer... ")
|
||||
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
|
||||
tokConfig := &tokenizer.TokenizerConfig{}
|
||||
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = data
|
||||
}
|
||||
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
|
||||
tokConfig.SpecialTokensMapJSON = data
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tokenizer: %w", err)
|
||||
}
|
||||
m.Tokenizer = tok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &qwen3.TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
|
||||
// Load transformer
|
||||
m.Transformer = &Flux2Transformer2DModel{}
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
|
||||
// Load VAE
|
||||
m.VAE = &AutoencoderKLFlux2{}
|
||||
if err := m.VAE.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE: %w", err)
|
||||
}
|
||||
|
||||
// Evaluate all weights in a single batch (reduces GPU sync overhead)
|
||||
fmt.Print(" Evaluating weights... ")
|
||||
allWeights := mlx.Collect(m.TextEncoder)
|
||||
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
|
||||
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
|
||||
mlx.Eval(allWeights...)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load scheduler config
|
||||
m.SchedulerConfig = DefaultSchedulerConfig()
|
||||
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
|
||||
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
|
||||
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements runner.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
|
||||
const MaxRefPixels = 728 * 728
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Enable MLX compilation for fused kernels
|
||||
mlx.EnableCompile()
|
||||
|
||||
// Apply defaults
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 4 // Klein default: 4 steps for distilled model
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
|
||||
}
|
||||
|
||||
// Determine output dimensions
|
||||
if len(cfg.InputImages) > 0 {
|
||||
// With input images, compute missing dimension from aspect ratio
|
||||
// Images are already EXIF-rotated by the caller
|
||||
bounds := cfg.InputImages[0].Bounds()
|
||||
imgW, imgH := bounds.Dx(), bounds.Dy()
|
||||
aspectRatio := float64(imgH) / float64(imgW)
|
||||
if cfg.Width > 0 && cfg.Height <= 0 {
|
||||
// Width specified, compute height
|
||||
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
|
||||
} else if cfg.Height > 0 && cfg.Width <= 0 {
|
||||
// Height specified, compute width
|
||||
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
|
||||
} else if cfg.Width <= 0 && cfg.Height <= 0 {
|
||||
// Neither specified, use input dimensions
|
||||
cfg.Width = int32(imgW)
|
||||
cfg.Height = int32(imgH)
|
||||
}
|
||||
}
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
|
||||
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
|
||||
pixels := int(cfg.Width) * int(cfg.Height)
|
||||
if pixels > MaxOutputPixels {
|
||||
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
|
||||
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
|
||||
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
|
||||
}
|
||||
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
|
||||
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
|
||||
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
|
||||
|
||||
tcfg := m.Transformer.TransformerConfig
|
||||
patchSize := m.VAE.Config.PatchSize
|
||||
|
||||
// Latent dimensions: image / 8 (VAE downscale) / patch_size
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
patchH := latentH / patchSize[0]
|
||||
patchW := latentW / patchSize[1]
|
||||
imgSeqLen := patchH * patchW
|
||||
|
||||
// Text encoding with multi-layer extraction (no padding, use true sequence length)
|
||||
fmt.Print(" Encoding prompt... ")
|
||||
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
|
||||
fmt.Println("✓")
|
||||
|
||||
// Encode reference images if provided
|
||||
var refTokens *ImageCondTokens
|
||||
var refHeights, refWidths []int32
|
||||
if len(cfg.InputImages) > 0 {
|
||||
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
|
||||
|
||||
var err error
|
||||
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode reference images: %w", err)
|
||||
}
|
||||
|
||||
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
|
||||
limitPixels := MaxRefPixels
|
||||
if len(cfg.InputImages) > 1 {
|
||||
limitPixels = MaxRefPixels / 2
|
||||
}
|
||||
for _, img := range cfg.InputImages {
|
||||
_, w, h := PrepareImage(img, limitPixels)
|
||||
refHeights = append(refHeights, int32(h/16))
|
||||
refWidths = append(refWidths, int32(w/16))
|
||||
}
|
||||
}
|
||||
|
||||
// Scheduler
|
||||
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
|
||||
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
|
||||
|
||||
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
|
||||
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
|
||||
latentChannels := m.VAE.Config.LatentChannels
|
||||
packedChannels := latentChannels * 4 // 32 * 4 = 128
|
||||
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
|
||||
|
||||
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
|
||||
// This matches diffusers' _pack_latents
|
||||
patches := packLatents(latents)
|
||||
noiseSeqLen := patches.Shape()[1]
|
||||
|
||||
// RoPE cache - includes reference images if present
|
||||
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
|
||||
|
||||
// Cleanup setup arrays when done
|
||||
defer func() {
|
||||
rope.Cos.Free()
|
||||
rope.Sin.Free()
|
||||
promptEmbeds.Free()
|
||||
if refTokens != nil {
|
||||
refTokens.Tokens.Free()
|
||||
}
|
||||
}()
|
||||
|
||||
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
|
||||
timesteps := make([]*mlx.Array, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
|
||||
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
|
||||
}
|
||||
|
||||
// Evaluate setup arrays
|
||||
fmt.Print(" Evaluating setup... ")
|
||||
setupStart := time.Now()
|
||||
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
|
||||
toEval = append(toEval, timesteps...)
|
||||
if refTokens != nil {
|
||||
toEval = append(toEval, refTokens.Tokens)
|
||||
}
|
||||
mlx.Eval(toEval...)
|
||||
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
|
||||
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
|
||||
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(0, cfg.Steps)
|
||||
}
|
||||
|
||||
loopStart := time.Now()
|
||||
stepStart := time.Now()
|
||||
|
||||
// Denoising loop
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GPU capture on step 2 if requested
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStartCapture(cfg.CapturePath)
|
||||
}
|
||||
|
||||
timestep := timesteps[i]
|
||||
|
||||
// Prepare input - concatenate noise patches with reference tokens if present
|
||||
imgInput := patches
|
||||
if refTokens != nil {
|
||||
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
|
||||
}
|
||||
|
||||
// Transformer forward pass
|
||||
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
|
||||
|
||||
// If we concatenated reference tokens, slice to only get noise portion
|
||||
if refTokens != nil {
|
||||
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
|
||||
}
|
||||
|
||||
// Scheduler step (keep reference to old patches for the computation graph)
|
||||
newPatches := scheduler.Step(output, patches, i)
|
||||
|
||||
if cfg.CapturePath != "" && i == 1 {
|
||||
mlx.MetalStopCapture()
|
||||
}
|
||||
|
||||
mlx.Eval(newPatches)
|
||||
patches = newPatches
|
||||
|
||||
elapsed := time.Since(stepStart).Seconds()
|
||||
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
if i == 0 {
|
||||
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
|
||||
} else {
|
||||
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
|
||||
}
|
||||
stepStart = time.Now()
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(i+1, cfg.Steps)
|
||||
}
|
||||
}
|
||||
|
||||
loopTime := time.Since(loopStart).Seconds()
|
||||
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
|
||||
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
|
||||
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
|
||||
|
||||
// Free timesteps now that denoising is done
|
||||
for _, ts := range timesteps {
|
||||
ts.Free()
|
||||
}
|
||||
|
||||
// VAE decode with tiling for larger images
|
||||
fmt.Print(" Decoding VAE... ")
|
||||
vaeStart := time.Now()
|
||||
// Enable tiling for images > 512x512 (latent > 64x64)
|
||||
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
|
||||
if patchH*2 > 64 || patchW*2 > 64 {
|
||||
m.VAE.Tiling = DefaultTilingConfig()
|
||||
}
|
||||
decoded := m.VAE.Decode(patches, patchH, patchW)
|
||||
mlx.Eval(decoded)
|
||||
|
||||
// Free patches now that decode is done
|
||||
patches.Free()
|
||||
|
||||
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
|
||||
func packLatents(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
|
||||
x = mlx.Reshape(x, B, C, H*W)
|
||||
return mlx.Transpose(x, 0, 2, 1)
|
||||
}
|
||||
|
||||
// LoadPersistent loads the model and keeps it in memory for repeated use.
|
||||
func LoadPersistent(modelName string) (*Model, error) {
|
||||
m := &Model{}
|
||||
if err := m.Load(modelName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
|
||||
const ImageRefScale = 10
|
||||
|
||||
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
|
||||
// Returns the processed image and its dimensions.
|
||||
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
// Cap pixels if needed (like diffusers cap_pixels)
|
||||
if limitPixels > 0 && w*h > limitPixels {
|
||||
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
|
||||
w = int(float64(w) * scale)
|
||||
h = int(float64(h) * scale)
|
||||
}
|
||||
|
||||
// Round down to multiple of 16
|
||||
w = (w / 16) * 16
|
||||
h = (h / 16) * 16
|
||||
|
||||
if w < 16 {
|
||||
w = 16
|
||||
}
|
||||
if h < 16 {
|
||||
h = 16
|
||||
}
|
||||
|
||||
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
|
||||
resized := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
return resized, w, h
|
||||
}
|
||||
|
||||
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
|
||||
func ImageToTensor(img image.Image) *mlx.Array {
|
||||
bounds := img.Bounds()
|
||||
w, h := bounds.Dx(), bounds.Dy()
|
||||
|
||||
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
|
||||
data := make([]float32, 3*h*w)
|
||||
|
||||
for y := 0; y < h; y++ {
|
||||
for x := 0; x < w; x++ {
|
||||
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
|
||||
// RGBA returns 16-bit values, convert to [-1, 1]
|
||||
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
|
||||
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
|
||||
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
|
||||
return arr
|
||||
}
|
||||
|
||||
// ImageCondTokens holds encoded reference image tokens.
|
||||
type ImageCondTokens struct {
|
||||
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
|
||||
}
|
||||
|
||||
// EncodeImageRefs encodes reference images using the VAE.
|
||||
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
|
||||
if len(images) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Limit reference images to reduce attention memory
|
||||
limitPixels := MaxRefPixels
|
||||
if len(images) > 1 {
|
||||
limitPixels = MaxRefPixels / 2
|
||||
}
|
||||
|
||||
var allTokens []*mlx.Array
|
||||
|
||||
for _, img := range images {
|
||||
// Prepare image (resize, crop to multiple of 16)
|
||||
prepared, prepW, prepH := PrepareImage(img, limitPixels)
|
||||
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
|
||||
|
||||
// Convert to tensor [-1, 1]
|
||||
tensor := ImageToTensor(prepared)
|
||||
|
||||
// Encode with VAE - returns [1, L, 128]
|
||||
encoded := m.VAE.EncodeImage(tensor)
|
||||
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
|
||||
|
||||
// Defer eval - will be done with other setup arrays
|
||||
allTokens = append(allTokens, squeezed)
|
||||
fmt.Println("✓")
|
||||
}
|
||||
|
||||
// For single image, just add batch dimension directly
|
||||
// For multiple images, concatenate first
|
||||
var tokens *mlx.Array
|
||||
if len(allTokens) == 1 {
|
||||
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
|
||||
} else {
|
||||
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
|
||||
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
|
||||
}
|
||||
|
||||
return &ImageCondTokens{Tokens: tokens}, nil
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// RoPEConfig holds 4D RoPE configuration for Flux2
|
||||
type RoPEConfig struct {
|
||||
Theta int32 // 2000 for Klein
|
||||
AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
|
||||
}
|
||||
|
||||
// RoPECache holds precomputed RoPE cos/sin values
|
||||
type RoPECache struct {
|
||||
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
||||
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
|
||||
TextLen int32 // Length of text sequence
|
||||
ImageLen int32 // Length of image sequence
|
||||
}
|
||||
|
||||
// PrepareTextIDs creates position IDs for text tokens.
|
||||
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
|
||||
// Returns: [seqLen, 4]
|
||||
func PrepareTextIDs(seqLen int32) *mlx.Array {
|
||||
ids := make([]float32, seqLen*4)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
idx := i * 4
|
||||
ids[idx+0] = 0 // T = 0
|
||||
ids[idx+1] = 0 // H = 0
|
||||
ids[idx+2] = 0 // W = 0
|
||||
ids[idx+3] = float32(i) // L = sequence position
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{seqLen, 4})
|
||||
}
|
||||
|
||||
// PrepareLatentIDs creates position IDs for image latent tokens.
|
||||
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
|
||||
// The latents are in row-major order (H then W).
|
||||
// Returns: [height*width, 4]
|
||||
func PrepareLatentIDs(height, width int32) *mlx.Array {
|
||||
seqLen := height * width
|
||||
ids := make([]float32, seqLen*4)
|
||||
idx := 0
|
||||
for h := int32(0); h < height; h++ {
|
||||
for w := int32(0); w < width; w++ {
|
||||
ids[idx*4+0] = 0 // T = 0
|
||||
ids[idx*4+1] = float32(h) // H = row
|
||||
ids[idx*4+2] = float32(w) // W = column
|
||||
ids[idx*4+3] = 0 // L = 0
|
||||
idx++
|
||||
}
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{seqLen, 4})
|
||||
}
|
||||
|
||||
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
|
||||
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
|
||||
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
|
||||
// Returns: [total_tokens, 4]
|
||||
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
|
||||
// Calculate total tokens
|
||||
totalTokens := int32(0)
|
||||
for i := range imageHeights {
|
||||
totalTokens += imageHeights[i] * imageWidths[i]
|
||||
}
|
||||
|
||||
ids := make([]float32, totalTokens*4)
|
||||
idx := int32(0)
|
||||
for imgIdx, h := range imageHeights {
|
||||
w := imageWidths[imgIdx]
|
||||
tValue := float32(scale * int32(imgIdx+1))
|
||||
for hi := int32(0); hi < h; hi++ {
|
||||
for wi := int32(0); wi < w; wi++ {
|
||||
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
|
||||
ids[idx*4+1] = float32(hi) // H = row
|
||||
ids[idx*4+2] = float32(wi) // W = column
|
||||
ids[idx*4+3] = 0 // L = 0
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
return mlx.NewArray(ids, []int32{totalTokens, 4})
|
||||
}
|
||||
|
||||
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
|
||||
// ids: [L, 4] with (T, H, W, L) coordinates
|
||||
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
|
||||
// theta: base frequency (2000 for Klein)
|
||||
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
|
||||
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
|
||||
shape := ids.Shape()
|
||||
seqLen := shape[0]
|
||||
|
||||
// Compute total head dim (sum of all axes dims)
|
||||
headDim := int32(0)
|
||||
for _, d := range axesDims {
|
||||
headDim += d
|
||||
}
|
||||
|
||||
// Extract each coordinate dimension
|
||||
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
|
||||
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
|
||||
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
|
||||
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
|
||||
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
|
||||
|
||||
// Compute frequencies for each axis
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
cosArrs := make([]*mlx.Array, 4)
|
||||
sinArrs := make([]*mlx.Array, 4)
|
||||
positions := []*mlx.Array{posT, posH, posW, posL}
|
||||
|
||||
for i, axisDim := range axesDims {
|
||||
half := axisDim / 2
|
||||
|
||||
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
|
||||
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
|
||||
freqs := make([]float32, half)
|
||||
for j := int32(0); j < half; j++ {
|
||||
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
|
||||
}
|
||||
freqArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// Compute pos * freq -> [L, half]
|
||||
posExpanded := positions[i] // [L, 1]
|
||||
args := mlx.Mul(posExpanded, freqArr) // [L, half]
|
||||
|
||||
// Compute cos and sin for this axis
|
||||
cosAxis := mlx.Cos(args) // [L, half]
|
||||
sinAxis := mlx.Sin(args) // [L, half]
|
||||
|
||||
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
|
||||
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
|
||||
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
|
||||
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
|
||||
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
|
||||
|
||||
sinAxis = mlx.ExpandDims(sinAxis, 2)
|
||||
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
|
||||
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
|
||||
|
||||
cosArrs[i] = cosAxis
|
||||
sinArrs[i] = sinAxis
|
||||
}
|
||||
|
||||
// Concatenate all axes: [L, headDim]
|
||||
cos := mlx.Concatenate(cosArrs, 1)
|
||||
sin := mlx.Concatenate(sinArrs, 1)
|
||||
|
||||
// Reshape to [1, L, 1, headDim] for broadcasting with attention
|
||||
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
|
||||
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
|
||||
|
||||
return cos, sin
|
||||
}
|
||||
|
||||
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
|
||||
// x: [B, L, nheads, head_dim]
|
||||
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
|
||||
// Returns: x with RoPE applied
|
||||
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
|
||||
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
nheads := shape[2]
|
||||
headDim := shape[3]
|
||||
half := headDim / 2
|
||||
|
||||
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
|
||||
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
|
||||
|
||||
// Extract real (index 0) and imag (index 1) parts
|
||||
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
|
||||
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
|
||||
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
|
||||
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
|
||||
|
||||
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
|
||||
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
|
||||
negXImag := mlx.Neg(xImag)
|
||||
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
|
||||
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
|
||||
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
|
||||
}
|
||||
|
||||
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
|
||||
// textLen: number of text tokens
|
||||
// noiseH, noiseW: dimensions of the noise latent in patch tokens
|
||||
// axesDims: [32, 32, 32, 32]
|
||||
// theta: 2000
|
||||
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
|
||||
// scale: time coordinate offset between reference images (e.g., 10)
|
||||
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
|
||||
textIDs := PrepareTextIDs(textLen)
|
||||
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
|
||||
|
||||
var allIDs *mlx.Array
|
||||
imageLen := noiseH * noiseW
|
||||
|
||||
if len(refHeights) > 0 {
|
||||
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
|
||||
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
|
||||
for i := range refHeights {
|
||||
imageLen += refHeights[i] * refWidths[i]
|
||||
}
|
||||
} else {
|
||||
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
|
||||
}
|
||||
|
||||
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
|
||||
cos = mlx.ToBFloat16(cos)
|
||||
sin = mlx.ToBFloat16(sin)
|
||||
|
||||
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// SchedulerConfig holds Flow-Match scheduler configuration
|
||||
type SchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
Shift float32 `json:"shift"` // 3.0 for Klein
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
|
||||
TimeShiftType string `json:"time_shift_type"` // "exponential" or "linear"
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns default config for Klein
|
||||
func DefaultSchedulerConfig() *SchedulerConfig {
|
||||
return &SchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
Shift: 3.0, // Klein uses 3.0
|
||||
UseDynamicShifting: true,
|
||||
TimeShiftType: "exponential",
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *SchedulerConfig
|
||||
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
|
||||
Sigmas []float32 // Noise levels at each timestep
|
||||
NumSteps int // Number of inference steps
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimesteps sets up the scheduler for the given number of inference steps
|
||||
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
|
||||
s.SetTimestepsWithMu(numSteps, 0)
|
||||
}
|
||||
|
||||
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
|
||||
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
|
||||
// Then applies time shift, appends 0.0 at end
|
||||
s.Sigmas = make([]float32, numSteps+1)
|
||||
|
||||
for i := 0; i < numSteps; i++ {
|
||||
// linspace(1, 1/num_steps, num_steps)
|
||||
var sigma float32
|
||||
if numSteps == 1 {
|
||||
sigma = 1.0
|
||||
} else {
|
||||
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
|
||||
}
|
||||
|
||||
// Apply time shift if using dynamic shifting
|
||||
if s.Config.UseDynamicShifting && mu != 0 {
|
||||
sigma = s.timeShift(mu, sigma)
|
||||
} else {
|
||||
// If not dynamic shifting, apply fixed shift scaling like diffusers
|
||||
shift := s.Config.Shift
|
||||
sigma = shift * sigma / (1 + (shift-1)*sigma)
|
||||
}
|
||||
s.Sigmas[i] = sigma
|
||||
}
|
||||
// Append terminal zero
|
||||
s.Sigmas[numSteps] = 0.0
|
||||
|
||||
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
|
||||
s.Timesteps = make([]float32, numSteps+1)
|
||||
for i, v := range s.Sigmas {
|
||||
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
|
||||
}
|
||||
}
|
||||
|
||||
// timeShift applies the dynamic time shift
|
||||
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
if s.Config.TimeShiftType == "linear" {
|
||||
return mu / (mu + (1.0/t-1.0))
|
||||
}
|
||||
// Default: exponential
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + (1.0/t - 1.0))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
sigmaNext := s.Sigmas[timestepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
|
||||
dt := sigmaNext - sigma
|
||||
|
||||
// Upcast to float32 for precision (matches diffusers)
|
||||
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
|
||||
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
|
||||
|
||||
scaledOutput := mlx.MulScalar(outputF32, dt)
|
||||
result := mlx.Add(sampleF32, scaledOutput)
|
||||
|
||||
// Cast back to bfloat16
|
||||
return mlx.ToBFloat16(result)
|
||||
}
|
||||
|
||||
// GetTimestep returns the timestep value at the given index
|
||||
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
|
||||
if idx < len(s.Timesteps) {
|
||||
return s.Timesteps[idx]
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise for sampling
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// CalculateShift computes the mu shift value for dynamic scheduling
|
||||
// Matches diffusers compute_empirical_mu function
|
||||
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
|
||||
a1, b1 := float32(8.73809524e-05), float32(1.89833333)
|
||||
a2, b2 := float32(0.00016927), float32(0.45666666)
|
||||
|
||||
seqLen := float32(imgSeqLen)
|
||||
|
||||
if imgSeqLen > 4300 {
|
||||
return a2*seqLen + b2
|
||||
}
|
||||
|
||||
m200 := a2*seqLen + b2
|
||||
m10 := a1*seqLen + b1
|
||||
|
||||
a := (m200 - m10) / 190.0
|
||||
b := m200 - 200.0*a
|
||||
return a*float32(numSteps) + b
|
||||
}
|
||||
@@ -1,562 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// TransformerConfig holds Flux2 transformer configuration
|
||||
type TransformerConfig struct {
|
||||
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
|
||||
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
|
||||
Eps float32 `json:"eps"` // 1e-6
|
||||
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
|
||||
InChannels int32 `json:"in_channels"` // 128
|
||||
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
|
||||
MLPRatio float32 `json:"mlp_ratio"` // 3.0
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
|
||||
NumLayers int32 `json:"num_layers"` // 5
|
||||
NumSingleLayers int32 `json:"num_single_layers"` // 20
|
||||
PatchSize int32 `json:"patch_size"` // 1
|
||||
RopeTheta int32 `json:"rope_theta"` // 2000
|
||||
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
|
||||
}
|
||||
|
||||
// Computed dimensions
|
||||
func (c *TransformerConfig) InnerDim() int32 {
|
||||
return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
|
||||
}
|
||||
|
||||
func (c *TransformerConfig) MLPHiddenDim() int32 {
|
||||
return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
|
||||
}
|
||||
|
||||
// TimestepEmbedder creates timestep embeddings
|
||||
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
|
||||
type TimestepEmbedder struct {
|
||||
Linear1 nn.LinearLayer `weight:"linear_1"`
|
||||
Linear2 nn.LinearLayer `weight:"linear_2"`
|
||||
EmbedDim int32 // 256
|
||||
}
|
||||
|
||||
// Forward creates sinusoidal embeddings and projects them
|
||||
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
|
||||
half := t.EmbedDim / 2
|
||||
freqs := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
|
||||
}
|
||||
freqsArr := mlx.NewArray(freqs, []int32{1, half})
|
||||
|
||||
// timesteps: [B] -> [B, 1]
|
||||
tExpanded := mlx.ExpandDims(timesteps, 1)
|
||||
// args: [B, half]
|
||||
args := mlx.Mul(tExpanded, freqsArr)
|
||||
|
||||
// [cos(args), sin(args)] -> [B, embed_dim]
|
||||
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
|
||||
|
||||
// MLP: linear_1 -> silu -> linear_2
|
||||
h := t.Linear1.Forward(sinEmbed)
|
||||
h = mlx.SiLU(h)
|
||||
return t.Linear2.Forward(h)
|
||||
}
|
||||
|
||||
// TimeGuidanceEmbed wraps the timestep embedder
|
||||
// Weight names: time_guidance_embed.timestep_embedder.*
|
||||
type TimeGuidanceEmbed struct {
|
||||
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
|
||||
}
|
||||
|
||||
// Forward computes timestep embeddings
|
||||
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
|
||||
return t.TimestepEmbedder.Forward(timesteps)
|
||||
}
|
||||
|
||||
// Modulation computes adaptive modulation parameters
|
||||
// Weight names: double_stream_modulation_img.linear.weight, etc.
|
||||
type Modulation struct {
|
||||
Linear nn.LinearLayer `weight:"linear"`
|
||||
}
|
||||
|
||||
// Forward computes modulation parameters
|
||||
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
|
||||
h := mlx.SiLU(temb)
|
||||
return m.Linear.Forward(h)
|
||||
}
|
||||
|
||||
// TransformerBlockAttn implements dual-stream attention
|
||||
// Weight names: transformer_blocks.N.attn.*
|
||||
type TransformerBlockAttn struct {
|
||||
// Image stream (separate Q, K, V projections)
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
// Note: to_out has .0 suffix in weights, handled specially
|
||||
ToOut0 nn.LinearLayer `weight:"to_out.0"`
|
||||
|
||||
// Text stream (add_ projections)
|
||||
AddQProj nn.LinearLayer `weight:"add_q_proj"`
|
||||
AddKProj nn.LinearLayer `weight:"add_k_proj"`
|
||||
AddVProj nn.LinearLayer `weight:"add_v_proj"`
|
||||
ToAddOut nn.LinearLayer `weight:"to_add_out"`
|
||||
|
||||
// QK norms for image stream
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"`
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
|
||||
// QK norms for text stream (added)
|
||||
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
|
||||
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
|
||||
}
|
||||
|
||||
// FeedForward implements SwiGLU MLP
|
||||
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
|
||||
type FeedForward struct {
|
||||
LinearIn nn.LinearLayer `weight:"linear_in"`
|
||||
LinearOut nn.LinearLayer `weight:"linear_out"`
|
||||
}
|
||||
|
||||
// Forward applies SwiGLU MLP
|
||||
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
|
||||
// LinearIn outputs 2x hidden dim for SwiGLU
|
||||
h := ff.LinearIn.Forward(x)
|
||||
shape := h.Shape()
|
||||
half := shape[len(shape)-1] / 2
|
||||
|
||||
// Split into gate and up
|
||||
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
|
||||
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
|
||||
|
||||
// SwiGLU: silu(gate) * up
|
||||
h = mlx.Mul(mlx.SiLU(gate), up)
|
||||
return ff.LinearOut.Forward(h)
|
||||
}
|
||||
|
||||
// TransformerBlock implements a dual-stream transformer block
|
||||
// Weight names: transformer_blocks.N.*
|
||||
type TransformerBlock struct {
|
||||
Attn *TransformerBlockAttn `weight:"attn"`
|
||||
FF *FeedForward `weight:"ff"`
|
||||
FFContext *FeedForward `weight:"ff_context"`
|
||||
|
||||
// Config (set after loading)
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward applies the dual-stream block
|
||||
// imgHidden: [B, imgLen, dim]
|
||||
// txtHidden: [B, txtLen, dim]
|
||||
// imgMod, txtMod: modulation params [B, 6*dim] each
|
||||
// cos, sin: RoPE values
|
||||
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
imgShape := imgHidden.Shape()
|
||||
B := imgShape[0]
|
||||
imgLen := imgShape[1]
|
||||
dim := imgShape[2]
|
||||
txtLen := txtHidden.Shape()[1]
|
||||
|
||||
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
|
||||
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
|
||||
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
|
||||
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
|
||||
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
|
||||
|
||||
// === Attention branch ===
|
||||
// Modulate inputs
|
||||
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
|
||||
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
|
||||
|
||||
// Compute Q, K, V for image stream (separate projections)
|
||||
imgQ := block.Attn.ToQ.Forward(imgNorm)
|
||||
imgK := block.Attn.ToK.Forward(imgNorm)
|
||||
imgV := block.Attn.ToV.Forward(imgNorm)
|
||||
|
||||
// Compute Q, K, V for text stream (add_ projections)
|
||||
txtQ := block.Attn.AddQProj.Forward(txtNorm)
|
||||
txtK := block.Attn.AddKProj.Forward(txtNorm)
|
||||
txtV := block.Attn.AddVProj.Forward(txtNorm)
|
||||
|
||||
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
|
||||
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
|
||||
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
|
||||
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
|
||||
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
|
||||
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
|
||||
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
|
||||
|
||||
// Apply QK norm (RMSNorm with learned scale)
|
||||
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
|
||||
imgK = applyQKNorm(imgK, block.Attn.NormK)
|
||||
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
|
||||
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
|
||||
|
||||
// Concatenate for joint attention: text first, then image
|
||||
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
|
||||
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
|
||||
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
|
||||
|
||||
// Apply RoPE
|
||||
q = ApplyRoPE4D(q, cos, sin)
|
||||
k = ApplyRoPE4D(k, cos, sin)
|
||||
|
||||
// Transpose for SDPA: [B, nheads, L, headDim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
||||
|
||||
// Transpose back: [B, L, nheads, headDim]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
|
||||
// Split back into txt and img
|
||||
totalLen := txtLen + imgLen
|
||||
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
|
||||
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
|
||||
|
||||
// Reshape and project
|
||||
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
|
||||
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
|
||||
txtOut = block.Attn.ToAddOut.Forward(txtOut)
|
||||
imgOut = block.Attn.ToOut0.Forward(imgOut)
|
||||
|
||||
// Apply gates and residual
|
||||
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
|
||||
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
|
||||
|
||||
// === MLP branch ===
|
||||
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
|
||||
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
|
||||
|
||||
imgFFOut := block.FF.Forward(imgNorm)
|
||||
txtFFOut := block.FFContext.Forward(txtNorm)
|
||||
|
||||
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
|
||||
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
|
||||
|
||||
return imgHidden, txtHidden
|
||||
}
|
||||
|
||||
// SingleTransformerBlockAttn implements attention for single-stream blocks
|
||||
// Weight names: single_transformer_blocks.N.attn.*
|
||||
type SingleTransformerBlockAttn struct {
|
||||
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
|
||||
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
|
||||
NormQ *mlx.Array `weight:"norm_q.weight"`
|
||||
NormK *mlx.Array `weight:"norm_k.weight"`
|
||||
}
|
||||
|
||||
// SingleTransformerBlock implements a single-stream transformer block
|
||||
// Weight names: single_transformer_blocks.N.*
|
||||
type SingleTransformerBlock struct {
|
||||
Attn *SingleTransformerBlockAttn `weight:"attn"`
|
||||
|
||||
// Config
|
||||
NHeads int32
|
||||
HeadDim int32
|
||||
InnerDim int32
|
||||
MLPHidDim int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// Forward applies the single-stream block
|
||||
// x: [B, L, dim] concatenated text+image
|
||||
// mod: modulation [B, 3*dim]
|
||||
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
dim := shape[2]
|
||||
|
||||
// Parse modulation: (shift, scale, gate)
|
||||
shift, scale, gate := parseModulation3(mod, dim, 0)
|
||||
|
||||
// Modulate input
|
||||
h := modulateLayerNorm(x, shift, scale)
|
||||
|
||||
// Fused projection: QKV + MLP gate/up
|
||||
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
|
||||
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
|
||||
|
||||
// Split: first 3*dim is QKV, rest is MLP
|
||||
qkvDim := 3 * block.InnerDim
|
||||
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
|
||||
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
|
||||
|
||||
// Split QKV
|
||||
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
|
||||
|
||||
// Reshape for attention
|
||||
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
|
||||
|
||||
// QK norm
|
||||
q = applyQKNorm(q, block.Attn.NormQ)
|
||||
k = applyQKNorm(k, block.Attn.NormK)
|
||||
|
||||
// Apply RoPE
|
||||
q = ApplyRoPE4D(q, cos, sin)
|
||||
k = ApplyRoPE4D(k, cos, sin)
|
||||
|
||||
// Transpose for SDPA
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// SDPA
|
||||
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
|
||||
|
||||
// Transpose back and reshape
|
||||
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
|
||||
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
|
||||
|
||||
// MLP: SwiGLU
|
||||
mlpShape := mlpIn.Shape()
|
||||
half := mlpShape[2] / 2
|
||||
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
|
||||
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
|
||||
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
|
||||
|
||||
// Concatenate attention and MLP for fused output
|
||||
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
|
||||
|
||||
// Output projection
|
||||
out := block.Attn.ToOut.Forward(combined)
|
||||
|
||||
// Apply gate and residual
|
||||
return mlx.Add(x, mlx.Mul(gate, out))
|
||||
}
|
||||
|
||||
// NormOut implements the output normalization with modulation
|
||||
// Weight names: norm_out.linear.weight
|
||||
type NormOut struct {
|
||||
Linear nn.LinearLayer `weight:"linear"`
|
||||
}
|
||||
|
||||
// Forward computes final modulated output
|
||||
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
dim := shape[2]
|
||||
|
||||
// Modulation: temb -> silu -> linear -> [shift, scale]
|
||||
mod := mlx.SiLU(temb)
|
||||
mod = n.Linear.Forward(mod)
|
||||
|
||||
// Split into scale and shift (diffusers order: scale first, shift second)
|
||||
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
|
||||
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
|
||||
shift = mlx.ExpandDims(shift, 1)
|
||||
scale = mlx.ExpandDims(scale, 1)
|
||||
|
||||
// Modulate with RMSNorm
|
||||
return modulateLayerNorm(x, shift, scale)
|
||||
}
|
||||
|
||||
// Flux2Transformer2DModel is the main Flux2 transformer
|
||||
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
|
||||
type Flux2Transformer2DModel struct {
|
||||
// Timestep embedding
|
||||
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
|
||||
|
||||
// Shared modulation
|
||||
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
|
||||
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
|
||||
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
|
||||
|
||||
// Embedders
|
||||
XEmbedder nn.LinearLayer `weight:"x_embedder"`
|
||||
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
|
||||
|
||||
// Transformer blocks
|
||||
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
|
||||
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
|
||||
|
||||
// Output
|
||||
NormOut *NormOut `weight:"norm_out"`
|
||||
ProjOut nn.LinearLayer `weight:"proj_out"`
|
||||
|
||||
*TransformerConfig
|
||||
}
|
||||
|
||||
// Load loads the Flux2 transformer from ollama blob storage.
|
||||
func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading transformer... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg TransformerConfig
|
||||
if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.TransformerConfig = &cfg
|
||||
|
||||
// Initialize slices
|
||||
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
|
||||
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
|
||||
|
||||
// Initialize TimeGuidanceEmbed with embed dim
|
||||
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
|
||||
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
|
||||
}
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Flux2Transformer2DModel) initComputedFields() {
|
||||
cfg := m.TransformerConfig
|
||||
innerDim := cfg.InnerDim()
|
||||
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
|
||||
|
||||
// Initialize transformer blocks
|
||||
for _, block := range m.TransformerBlocks {
|
||||
block.NHeads = cfg.NumAttentionHeads
|
||||
block.HeadDim = cfg.AttentionHeadDim
|
||||
block.Scale = scale
|
||||
}
|
||||
|
||||
// Initialize single transformer blocks
|
||||
for _, block := range m.SingleTransformerBlocks {
|
||||
block.NHeads = cfg.NumAttentionHeads
|
||||
block.HeadDim = cfg.AttentionHeadDim
|
||||
block.InnerDim = innerDim
|
||||
block.MLPHidDim = cfg.MLPHiddenDim()
|
||||
block.Scale = scale
|
||||
}
|
||||
}
|
||||
|
||||
// Forward runs the Flux2 transformer
|
||||
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
|
||||
patchShape := patches.Shape()
|
||||
B := patchShape[0]
|
||||
imgLen := patchShape[1]
|
||||
txtLen := txtEmbeds.Shape()[1]
|
||||
|
||||
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
|
||||
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
|
||||
|
||||
// Compute timestep embedding
|
||||
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
|
||||
|
||||
// Embed patches and text
|
||||
imgHidden := m.XEmbedder.Forward(patches)
|
||||
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
|
||||
|
||||
// Compute shared modulation
|
||||
imgMod := m.DoubleStreamModulationImg.Forward(temb)
|
||||
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
|
||||
singleMod := m.SingleStreamModulation.Forward(temb)
|
||||
|
||||
// Double (dual-stream) blocks
|
||||
for _, block := range m.TransformerBlocks {
|
||||
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
|
||||
}
|
||||
|
||||
// Concatenate for single-stream: text first, then image
|
||||
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
|
||||
|
||||
// Single-stream blocks
|
||||
for _, block := range m.SingleTransformerBlocks {
|
||||
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
|
||||
}
|
||||
|
||||
// Extract image portion
|
||||
totalLen := txtLen + imgLen
|
||||
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
|
||||
|
||||
// Final norm and projection
|
||||
imgOut = m.NormOut.Forward(imgOut, temb)
|
||||
return m.ProjOut.Forward(imgOut)
|
||||
}
|
||||
|
||||
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
|
||||
// See applyQKNorm function below
|
||||
|
||||
// compiledSwiGLU fuses: silu(gate) * up
|
||||
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
|
||||
var compiledSwiGLU *mlx.CompiledFunc
|
||||
|
||||
func getCompiledSwiGLU() *mlx.CompiledFunc {
|
||||
if compiledSwiGLU == nil {
|
||||
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
gate, up := inputs[0], inputs[1]
|
||||
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
|
||||
}, true)
|
||||
}
|
||||
return compiledSwiGLU
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
|
||||
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
B := mod.Shape()[0]
|
||||
start := offset * dim
|
||||
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
|
||||
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
|
||||
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
|
||||
|
||||
// Expand for broadcasting [B, dim] -> [B, 1, dim]
|
||||
shift = mlx.ExpandDims(shift, 1)
|
||||
scale = mlx.ExpandDims(scale, 1)
|
||||
gate = mlx.ExpandDims(gate, 1)
|
||||
|
||||
return shift, scale, gate
|
||||
}
|
||||
|
||||
// modulateLayerNorm applies LayerNorm then shift/scale modulation
|
||||
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
|
||||
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
|
||||
// Fast LayerNorm without learnable params
|
||||
x = mlx.LayerNorm(x, 1e-6)
|
||||
|
||||
// Modulate: x * (1 + scale) + shift
|
||||
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
|
||||
return mlx.Add(x, shift)
|
||||
}
|
||||
|
||||
// splitQKV splits a fused QKV tensor into Q, K, V
|
||||
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
|
||||
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
|
||||
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
|
||||
return q, k, v
|
||||
}
|
||||
|
||||
// applyQKNorm applies RMSNorm with learned scale (no bias)
|
||||
// Uses the optimized mlx_fast_rms_norm
|
||||
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
|
||||
return mlx.RMSNorm(x, scale, 1e-6)
|
||||
}
|
||||
@@ -1,804 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package flux2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// VAEConfig holds AutoencoderKLFlux2 configuration
|
||||
type VAEConfig struct {
|
||||
ActFn string `json:"act_fn"` // "silu"
|
||||
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
|
||||
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
|
||||
ForceUpcast bool `json:"force_upcast"` // true
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 32
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 2
|
||||
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
PatchSize []int32 `json:"patch_size"` // [2, 2]
|
||||
SampleSize int32 `json:"sample_size"` // 1024
|
||||
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
|
||||
UseQuantConv bool `json:"use_quant_conv"` // true
|
||||
}
|
||||
|
||||
// BatchNorm2D implements 2D batch normalization with running statistics
|
||||
type BatchNorm2D struct {
|
||||
RunningMean *mlx.Array // [C]
|
||||
RunningVar *mlx.Array // [C]
|
||||
Weight *mlx.Array // [C] gamma
|
||||
Bias *mlx.Array // [C] beta
|
||||
Eps float32
|
||||
Momentum float32
|
||||
}
|
||||
|
||||
// Forward applies batch normalization (inference mode - uses running stats)
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[3]
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, 1, C]
|
||||
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
|
||||
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
|
||||
|
||||
// Normalize: (x - mean) / sqrt(var + eps)
|
||||
xNorm := mlx.Sub(x, mean)
|
||||
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
|
||||
|
||||
// Scale and shift (only if affine=True)
|
||||
if bn.Weight != nil {
|
||||
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if bn.Bias != nil {
|
||||
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Denormalize inverts the batch normalization
|
||||
// Used when decoding latents
|
||||
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[3]
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, 1, C]
|
||||
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
|
||||
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
|
||||
|
||||
// Inverse: first undo affine, then undo normalization
|
||||
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
|
||||
if bn.Bias != nil {
|
||||
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
|
||||
x = mlx.Sub(x, bias)
|
||||
}
|
||||
if bn.Weight != nil {
|
||||
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
|
||||
x = mlx.Div(x, weight)
|
||||
}
|
||||
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
|
||||
x = mlx.Add(x, mean)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// GroupNormLayer implements group normalization
|
||||
// Reused from zimage package pattern
|
||||
type GroupNormLayer struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Forward applies group normalization
|
||||
// Input and output are in NHWC format [B, H, W, C]
|
||||
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Reshape to [B, H, W, groups, C/groups]
|
||||
groupSize := C / gn.NumGroups
|
||||
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
|
||||
sq := mlx.Square(xCentered)
|
||||
variance := mlx.Mean(sq, 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back to [B, H, W, C]
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Conv2D represents a 2D convolution layer (reused pattern)
|
||||
type Conv2D struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias,optional"`
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
|
||||
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
|
||||
if field == "Weight" {
|
||||
return mlx.Transpose(arr, 0, 2, 3, 1)
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// Forward applies convolution (NHWC format)
|
||||
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
|
||||
|
||||
if conv.Bias != nil {
|
||||
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// ResnetBlock2D implements a ResNet block for VAE
|
||||
type ResnetBlock2D struct {
|
||||
Norm1 *GroupNormLayer `weight:"norm1"`
|
||||
Conv1 *Conv2D `weight:"conv1"`
|
||||
Norm2 *GroupNormLayer `weight:"norm2"`
|
||||
Conv2 *Conv2D `weight:"conv2"`
|
||||
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"`
|
||||
}
|
||||
|
||||
// Forward applies the ResNet block
|
||||
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
h := rb.Norm1.Forward(x)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
|
||||
h = rb.Norm2.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
|
||||
if rb.ConvShortcut != nil {
|
||||
x = rb.ConvShortcut.Forward(x)
|
||||
}
|
||||
|
||||
return mlx.Add(h, x)
|
||||
}
|
||||
|
||||
// VAEAttentionBlock implements self-attention for VAE
|
||||
type VAEAttentionBlock struct {
|
||||
GroupNorm *GroupNormLayer `weight:"group_norm"`
|
||||
ToQ nn.LinearLayer `weight:"to_q"`
|
||||
ToK nn.LinearLayer `weight:"to_k"`
|
||||
ToV nn.LinearLayer `weight:"to_v"`
|
||||
ToOut nn.LinearLayer `weight:"to_out.0"`
|
||||
}
|
||||
|
||||
// Forward applies attention (NHWC format)
|
||||
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
h := ab.GroupNorm.Forward(x)
|
||||
h = mlx.Reshape(h, B, H*W, C)
|
||||
|
||||
q := ab.ToQ.Forward(h)
|
||||
k := ab.ToK.Forward(h)
|
||||
v := ab.ToV.Forward(h)
|
||||
|
||||
q = mlx.ExpandDims(q, 1)
|
||||
k = mlx.ExpandDims(k, 1)
|
||||
v = mlx.ExpandDims(v, 1)
|
||||
|
||||
scale := float32(1.0 / math.Sqrt(float64(C)))
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
|
||||
out = mlx.Squeeze(out, 1)
|
||||
|
||||
out = ab.ToOut.Forward(out)
|
||||
out = mlx.Reshape(out, B, H, W, C)
|
||||
out = mlx.Add(out, residual)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UpDecoderBlock2D implements an upsampling decoder block
|
||||
type UpDecoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Upsample *Conv2D
|
||||
}
|
||||
|
||||
// Forward applies the up decoder block
|
||||
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range ub.ResnetBlocks {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
|
||||
if ub.Upsample != nil {
|
||||
x = upsample2x(x)
|
||||
x = ub.Upsample.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2x performs 2x nearest neighbor upsampling
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
|
||||
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
|
||||
hIdx = mlx.Reshape(hIdx, H, 1)
|
||||
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
|
||||
hIdx = mlx.Reshape(hIdx, H*2)
|
||||
|
||||
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
|
||||
wIdx = mlx.Reshape(wIdx, W, 1)
|
||||
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
|
||||
wIdx = mlx.Reshape(wIdx, W*2)
|
||||
|
||||
x = mlx.Take(x, hIdx, 1)
|
||||
x = mlx.Take(x, wIdx, 2)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block with attention
|
||||
type VAEMidBlock struct {
|
||||
Resnet1 *ResnetBlock2D
|
||||
Attention *VAEAttentionBlock
|
||||
Resnet2 *ResnetBlock2D
|
||||
}
|
||||
|
||||
// Forward applies the mid block
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
x = mb.Resnet1.Forward(x)
|
||||
x = mb.Attention.Forward(x)
|
||||
x = mb.Resnet2.Forward(x)
|
||||
return x
|
||||
}
|
||||
|
||||
// DefaultTilingConfig returns reasonable defaults for tiled decoding
|
||||
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
|
||||
func DefaultTilingConfig() *vae.TilingConfig {
|
||||
return vae.DefaultTilingConfig()
|
||||
}
|
||||
|
||||
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
|
||||
type AutoencoderKLFlux2 struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Encoder components (for image editing)
|
||||
EncoderConvIn *Conv2D
|
||||
EncoderMid *VAEMidBlock
|
||||
EncoderDown []*DownEncoderBlock2D
|
||||
EncoderNormOut *GroupNormLayer
|
||||
EncoderConvOut *Conv2D
|
||||
|
||||
// Decoder components
|
||||
DecoderConvIn *Conv2D
|
||||
DecoderMid *VAEMidBlock
|
||||
DecoderUp []*UpDecoderBlock2D
|
||||
DecoderNormOut *GroupNormLayer
|
||||
DecoderConvOut *Conv2D
|
||||
|
||||
// Quant conv layers
|
||||
QuantConv *Conv2D
|
||||
PostQuantConv *Conv2D
|
||||
|
||||
// BatchNorm for latent normalization
|
||||
LatentBN *BatchNorm2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// DownEncoderBlock2D implements a downsampling encoder block
|
||||
type DownEncoderBlock2D struct {
|
||||
ResnetBlocks []*ResnetBlock2D
|
||||
Downsample *Conv2D
|
||||
}
|
||||
|
||||
// Forward applies the down encoder block
|
||||
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range db.ResnetBlocks {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
|
||||
if db.Downsample != nil {
|
||||
// Pad then conv with stride 2
|
||||
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
|
||||
x = db.Downsample.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Load loads the Flux2 VAE from ollama blob storage.
|
||||
func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights, &cfg)
|
||||
}
|
||||
|
||||
// loadWeights loads VAE weights from any WeightSource
|
||||
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
|
||||
var err error
|
||||
|
||||
// Load encoder components (for image conditioning)
|
||||
if err := m.loadEncoderWeights(weights, cfg); err != nil {
|
||||
return fmt.Errorf("encoder: %w", err)
|
||||
}
|
||||
|
||||
// Load decoder conv_in
|
||||
m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_in: %w", err)
|
||||
}
|
||||
|
||||
// Load mid block
|
||||
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoder.mid_block: %w", err)
|
||||
}
|
||||
|
||||
// Load up blocks
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
|
||||
hasUpsample := i < numBlocks-1
|
||||
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load decoder conv_norm_out and conv_out
|
||||
m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
|
||||
if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_norm_out: %w", err)
|
||||
}
|
||||
|
||||
m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
|
||||
return fmt.Errorf("decoder.conv_out: %w", err)
|
||||
}
|
||||
|
||||
// Load post_quant_conv
|
||||
if cfg.UsePostQuantConv {
|
||||
m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
|
||||
if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
|
||||
return fmt.Errorf("post_quant_conv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load latent BatchNorm (affine=False, so no weight/bias)
|
||||
bnMean, err := weights.GetTensor("bn.running_mean")
|
||||
if err != nil {
|
||||
return fmt.Errorf("bn.running_mean: %w", err)
|
||||
}
|
||||
bnVar, err := weights.GetTensor("bn.running_var")
|
||||
if err != nil {
|
||||
return fmt.Errorf("bn.running_var: %w", err)
|
||||
}
|
||||
m.LatentBN = &BatchNorm2D{
|
||||
RunningMean: bnMean,
|
||||
RunningVar: bnVar,
|
||||
Weight: nil, // affine=False
|
||||
Bias: nil, // affine=False
|
||||
Eps: cfg.BatchNormEps,
|
||||
Momentum: cfg.BatchNormMomentum,
|
||||
}
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadVAEMidBlock loads the mid block.
|
||||
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
|
||||
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VAEMidBlock{
|
||||
Resnet1: resnet1,
|
||||
Attention: attention,
|
||||
Resnet2: resnet2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// loadResnetBlock2D loads a ResNet block.
|
||||
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
|
||||
block := &ResnetBlock2D{
|
||||
Norm1: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
Conv1: &Conv2D{Stride: 1, Padding: 1},
|
||||
Norm2: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
Conv2: &Conv2D{Stride: 1, Padding: 1},
|
||||
ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
|
||||
}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If ConvShortcut wasn't loaded (no weights found), nil it out
|
||||
if block.ConvShortcut.Weight == nil {
|
||||
block.ConvShortcut = nil
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// loadVAEAttentionBlock loads an attention block using LoadModule.
|
||||
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
|
||||
ab := &VAEAttentionBlock{
|
||||
GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
|
||||
}
|
||||
if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ab, nil
|
||||
}
|
||||
|
||||
// loadUpDecoderBlock2D loads an up decoder block.
|
||||
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var upsample *Conv2D
|
||||
if hasUpsample {
|
||||
upsample = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &UpDecoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Upsample: upsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
|
||||
// This is the inverse of the VAE's patchify for feeding to transformer
|
||||
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
patchH := vae.Config.PatchSize[0]
|
||||
patchW := vae.Config.PatchSize[1]
|
||||
|
||||
pH := H / patchH
|
||||
pW := W / patchW
|
||||
|
||||
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
|
||||
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
|
||||
}
|
||||
|
||||
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
|
||||
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
|
||||
patchH := vae.Config.PatchSize[0]
|
||||
patchW := vae.Config.PatchSize[1]
|
||||
|
||||
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
|
||||
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
|
||||
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
|
||||
H := pH * patchH
|
||||
W := pW * patchW
|
||||
return mlx.Reshape(x, B, C, H, W)
|
||||
}
|
||||
|
||||
// denormalizePatchified applies inverse batch normalization to patchified latents.
|
||||
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
|
||||
// Output: [B, L, 128] denormalized
|
||||
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[2] // 128
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, C]
|
||||
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
|
||||
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
|
||||
|
||||
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
|
||||
if vae.LatentBN.Bias != nil {
|
||||
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
|
||||
x = mlx.Sub(x, bias)
|
||||
}
|
||||
if vae.LatentBN.Weight != nil {
|
||||
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
|
||||
x = mlx.Div(x, weight)
|
||||
}
|
||||
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
|
||||
x = mlx.Add(x, mean)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Decode decodes latent patches to images.
|
||||
// If Tiling is set, uses tiled decoding to reduce memory for large images.
|
||||
// latents: [B, L, C*4] patchified latents from transformer
|
||||
// pH, pW: patch grid dimensions
|
||||
// Returns: [B, 3, H, W] image tensor
|
||||
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
|
||||
// Denormalize patchified latents
|
||||
z := v.denormalizePatchified(latents)
|
||||
|
||||
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
|
||||
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
|
||||
|
||||
// Convert NCHW -> NHWC for processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
|
||||
// Use tiled decoding if enabled
|
||||
if v.Tiling != nil {
|
||||
mlx.Eval(z)
|
||||
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
|
||||
}
|
||||
|
||||
// Direct decode (no tiling)
|
||||
h := v.decodeTile(z)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
return h
|
||||
}
|
||||
|
||||
// decodeTile decodes a single latent tile to pixels (internal helper)
|
||||
// z: [B, H, W, C] latent tile in NHWC format
|
||||
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
|
||||
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
|
||||
// Post-quant conv
|
||||
if vae.PostQuantConv != nil {
|
||||
z = vae.PostQuantConv.Forward(z)
|
||||
}
|
||||
|
||||
// Decoder
|
||||
h := vae.DecoderConvIn.Forward(z)
|
||||
h = vae.DecoderMid.Forward(h)
|
||||
|
||||
for _, upBlock := range vae.DecoderUp {
|
||||
h = upBlock.Forward(h)
|
||||
}
|
||||
|
||||
h = vae.DecoderNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.DecoderConvOut.Forward(h)
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// loadEncoderWeights loads the encoder components for image conditioning
|
||||
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
|
||||
var err error
|
||||
|
||||
// Load encoder conv_in
|
||||
m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_in: %w", err)
|
||||
}
|
||||
|
||||
// Load encoder down blocks
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
|
||||
hasDownsample := i < numBlocks-1
|
||||
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load encoder mid block
|
||||
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoder.mid_block: %w", err)
|
||||
}
|
||||
|
||||
// Load encoder conv_norm_out and conv_out
|
||||
m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
|
||||
if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_norm_out: %w", err)
|
||||
}
|
||||
|
||||
m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
|
||||
if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
|
||||
return fmt.Errorf("encoder.conv_out: %w", err)
|
||||
}
|
||||
|
||||
// Load quant_conv (for encoding)
|
||||
if cfg.UseQuantConv {
|
||||
m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
|
||||
if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
|
||||
return fmt.Errorf("quant_conv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadDownEncoderBlock2D loads a down encoder block.
|
||||
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
|
||||
resnets := make([]*ResnetBlock2D, numLayers)
|
||||
for i := int32(0); i < numLayers; i++ {
|
||||
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
|
||||
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resnets[i] = resnet
|
||||
}
|
||||
|
||||
var downsample *Conv2D
|
||||
if hasDownsample {
|
||||
downsample = &Conv2D{Stride: 2, Padding: 0}
|
||||
if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &DownEncoderBlock2D{
|
||||
ResnetBlocks: resnets,
|
||||
Downsample: downsample,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncodeImage encodes an image to normalized latents.
|
||||
// image: [B, 3, H, W] image tensor in [-1, 1]
|
||||
// Returns: [B, L, C*4] patchified normalized latents
|
||||
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
|
||||
// Convert NCHW -> NHWC
|
||||
x := mlx.Transpose(image, 0, 2, 3, 1)
|
||||
|
||||
// Encoder
|
||||
h := vae.EncoderConvIn.Forward(x)
|
||||
|
||||
for _, downBlock := range vae.EncoderDown {
|
||||
h = downBlock.Forward(h)
|
||||
}
|
||||
|
||||
h = vae.EncoderMid.Forward(h)
|
||||
h = vae.EncoderNormOut.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = vae.EncoderConvOut.Forward(h)
|
||||
|
||||
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
|
||||
if vae.QuantConv != nil {
|
||||
h = vae.QuantConv.Forward(h)
|
||||
}
|
||||
|
||||
// Take only the mean (first latent_channels) - deterministic encoding
|
||||
// h is [B, H, W, 64] -> take first 32 channels for mean
|
||||
shape := h.Shape()
|
||||
latentChannels := vae.Config.LatentChannels // 32
|
||||
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
|
||||
|
||||
// Convert NHWC -> NCHW for patchifying
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
|
||||
// Patchify: [B, C, H, W] -> [B, L, C*4]
|
||||
h = vae.Patchify(h)
|
||||
|
||||
// Apply BatchNorm on patchified latents [B, L, 128]
|
||||
// The BatchNorm has 128 channels matching the patchified dimension
|
||||
h = vae.normalizePatchified(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// normalizePatchified applies batch normalization to patchified latents.
|
||||
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
|
||||
// Output: [B, L, 128] normalized
|
||||
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
C := shape[2] // 128
|
||||
|
||||
// Reshape stats for broadcasting [1, 1, C]
|
||||
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
|
||||
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
|
||||
|
||||
// Normalize: (x - mean) / sqrt(var + eps)
|
||||
xNorm := mlx.Sub(x, mean)
|
||||
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
|
||||
|
||||
// Scale and shift (only if affine=True)
|
||||
if vae.LatentBN.Weight != nil {
|
||||
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if vae.LatentBN.Bias != nil {
|
||||
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
693
x/imagegen/models/glm_image/glm_image.go
Normal file
@@ -0,0 +1,693 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
|
||||
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
|
||||
// Special tokens: 0=pad, 1=eos, 2=unk
|
||||
type ByT5Tokenizer struct {
|
||||
PadTokenID int32
|
||||
EOSTokenID int32
|
||||
UNKTokenID int32
|
||||
}
|
||||
|
||||
// NewByT5Tokenizer creates a new ByT5 tokenizer
|
||||
func NewByT5Tokenizer() *ByT5Tokenizer {
|
||||
return &ByT5Tokenizer{
|
||||
PadTokenID: 0,
|
||||
EOSTokenID: 1,
|
||||
UNKTokenID: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode converts a string to token IDs
|
||||
func (t *ByT5Tokenizer) Encode(text string) []int32 {
|
||||
bytes := []byte(text)
|
||||
tokens := make([]int32, len(bytes))
|
||||
for i, b := range bytes {
|
||||
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
|
||||
// (tokens 0, 1, 2 are PAD, EOS, UNK)
|
||||
tokens[i] = int32(b) + 3
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
|
||||
bytes := make([]byte, 0, len(tokens))
|
||||
for _, tok := range tokens {
|
||||
if tok >= 3 && tok < 259 {
|
||||
bytes = append(bytes, byte(tok-3))
|
||||
}
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
|
||||
GuidanceScale float32 // Guidance scale (default: 1.5)
|
||||
Width int32 // Image width (default: 1024, must be divisible by 32)
|
||||
Height int32 // Image height (default: 1024, must be divisible by 32)
|
||||
Steps int // Diffusion denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// AR generation options
|
||||
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
|
||||
Temperature float32 // AR sampling temperature (default: 0.9)
|
||||
TopP float32 // Nucleus sampling (default: 0.75)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with stage and step progress.
|
||||
type ProgressFunc func(stage string, step, totalSteps int)
|
||||
|
||||
// Model represents a GLM-Image hybrid model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
|
||||
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
|
||||
TextEncoder *T5TextEncoder
|
||||
VisionLanguageEncoder *VisionLanguageEncoder
|
||||
Transformer *DiffusionTransformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the GLM-Image model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
// Used for T5 text encoder (glyph embeddings)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizer(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder (~830MB)
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder (~19GB, 9B params)
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer (~13GB, 7B params)
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder (~775MB)
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the model from a directory path (not ollama manifest)
|
||||
func (m *Model) LoadFromPath(modelPath string) error {
|
||||
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelPath
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizerFromPath(modelPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal generation pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.5
|
||||
}
|
||||
// Calculate MaxVisualTokens based on image dimensions
|
||||
// GLM-Image generates TWO grids of visual tokens:
|
||||
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
|
||||
// 2. Then: target (large) grid - tokenH × tokenW tokens
|
||||
// After generation, we extract only the TARGET grid tokens for diffusion.
|
||||
factor := int32(32)
|
||||
tokenH := cfg.Height / factor
|
||||
tokenW := cfg.Width / factor
|
||||
targetGridTokens := tokenH * tokenW
|
||||
|
||||
// Compute prev grid dimensions using diffusers formula:
|
||||
// ratio = token_h / token_w
|
||||
// prev_token_h = int(sqrt(ratio) * 16)
|
||||
// prev_token_w = int(sqrt(1/ratio) * 16)
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
|
||||
prevGridTokens := prevTokenH * prevTokenW
|
||||
|
||||
// Total tokens to generate = prev grid + target grid
|
||||
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
|
||||
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
|
||||
if cfg.Temperature <= 0 {
|
||||
cfg.Temperature = 0.9
|
||||
}
|
||||
if cfg.TopP <= 0 {
|
||||
cfg.TopP = 0.75
|
||||
}
|
||||
|
||||
// Ensure dimensions are divisible by 32
|
||||
cfg.Width = (cfg.Width / 32) * 32
|
||||
cfg.Height = (cfg.Height / 32) * 32
|
||||
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
|
||||
// Progress callback helper
|
||||
progress := func(stage string, step, total int) {
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(stage, step, total)
|
||||
}
|
||||
}
|
||||
|
||||
// === PHASE 1: T5 Text Encoding ===
|
||||
fmt.Println("[T5] Encoding glyph text...")
|
||||
progress("text_encoding", 0, 1)
|
||||
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
mlx.Keep(textEmbed)
|
||||
mlx.Eval(textEmbed)
|
||||
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
|
||||
progress("text_encoding", 1, 1)
|
||||
|
||||
// === PHASE 2: AR Visual Token Generation ===
|
||||
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
|
||||
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
|
||||
visualTokens := m.VisionLanguageEncoder.Generate(
|
||||
cfg.Prompt,
|
||||
m.GLMTokenizer,
|
||||
cfg.MaxVisualTokens,
|
||||
cfg.Temperature,
|
||||
cfg.TopP,
|
||||
cfg.Seed,
|
||||
cfg.Height,
|
||||
cfg.Width,
|
||||
func(step int) {
|
||||
if step%100 == 0 || step < 10 {
|
||||
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
|
||||
}
|
||||
progress("ar_generation", step, int(cfg.MaxVisualTokens))
|
||||
},
|
||||
)
|
||||
mlx.Keep(visualTokens)
|
||||
mlx.Eval(visualTokens)
|
||||
fmt.Printf("[AR] Done generating visual tokens\n")
|
||||
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
|
||||
|
||||
vtShape := visualTokens.Shape()
|
||||
totalGenerated := vtShape[1]
|
||||
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
|
||||
|
||||
// Extract only the TARGET grid tokens (skip the prev grid tokens)
|
||||
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
|
||||
// large_image_start_offset = prev_grid_size
|
||||
var targetGridVisualTokens *mlx.Array
|
||||
if totalGenerated >= prevGridTokens+targetGridTokens {
|
||||
// Full generation completed - extract target grid
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, prevGridTokens + targetGridTokens})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
} else if totalGenerated > prevGridTokens {
|
||||
// Partial target grid - take what we have
|
||||
actualTargetTokens := totalGenerated - prevGridTokens
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, totalGenerated})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
|
||||
actualTargetTokens, targetGridTokens)
|
||||
} else {
|
||||
// Not enough tokens - EOS came too early
|
||||
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
|
||||
totalGenerated, prevGridTokens)
|
||||
}
|
||||
|
||||
// === PHASE 3: Diffusion Decoding ===
|
||||
// Setup scheduler with dynamic shift based on image size
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
|
||||
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Initialize noise latents [B, C, H, W]
|
||||
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
|
||||
// target_grid tokens -> 2x upsample -> patch_count
|
||||
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
|
||||
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
|
||||
|
||||
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
|
||||
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
|
||||
mlx.Keep(priorEmbed)
|
||||
mlx.Eval(priorEmbed)
|
||||
|
||||
// Prepare text conditioning (project T5 embeddings)
|
||||
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
|
||||
mlx.Keep(textCond)
|
||||
mlx.Eval(textCond)
|
||||
|
||||
// === CFG Setup ===
|
||||
// For classifier-free guidance, we need unconditional (negative) text embeddings
|
||||
// GLM-Image uses empty string "" for negative prompt
|
||||
doCFG := cfg.GuidanceScale > 1.0
|
||||
var negativeTextCond *mlx.Array
|
||||
if doCFG {
|
||||
// Encode empty string for negative prompt
|
||||
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
|
||||
mlx.Keep(negativeTextEmbed)
|
||||
mlx.Eval(negativeTextEmbed)
|
||||
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
|
||||
mlx.Keep(negativeTextCond)
|
||||
mlx.Eval(negativeTextCond)
|
||||
negativeTextEmbed.Free()
|
||||
}
|
||||
|
||||
// Prepare conditioning inputs
|
||||
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
|
||||
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
|
||||
targetSize = mlx.ToBFloat16(targetSize)
|
||||
cropCoords = mlx.ToBFloat16(cropCoords)
|
||||
mlx.Keep(targetSize)
|
||||
mlx.Keep(cropCoords)
|
||||
mlx.Eval(targetSize, cropCoords)
|
||||
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
|
||||
progress("diffusion", 0, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
latents.Free()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Get timestep value for the transformer
|
||||
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
|
||||
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
|
||||
timestepVal := scheduler.Timesteps[i] - 1
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
|
||||
|
||||
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
// Transformer forward with MMDiT architecture
|
||||
// Conditional pass (with text + prior embeddings)
|
||||
outputCond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed,
|
||||
textCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
false, // priorTokenDrop = false for conditional
|
||||
)
|
||||
|
||||
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
|
||||
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
var noisePred *mlx.Array
|
||||
if doCFG {
|
||||
// Unconditional pass (empty text, dropped prior embeddings)
|
||||
outputUncond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
|
||||
negativeTextCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
true, // priorTokenDrop = true for unconditional
|
||||
)
|
||||
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
|
||||
diff := mlx.Sub(noisePredCond, noisePredUncond)
|
||||
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
|
||||
noisePred = mlx.Add(noisePredUncond, scaled)
|
||||
} else {
|
||||
noisePred = noisePredCond
|
||||
}
|
||||
|
||||
// Scheduler step
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
progress("diffusion", i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// Cleanup intermediate arrays
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
if negativeTextCond != nil {
|
||||
negativeTextCond.Free()
|
||||
}
|
||||
targetSize.Free()
|
||||
cropCoords.Free()
|
||||
|
||||
// === PHASE 4: VAE Decode ===
|
||||
progress("vae_decode", 0, 1)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
mlx.Eval(decoded)
|
||||
latents.Free()
|
||||
progress("vae_decode", 1, 1)
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
|
||||
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
|
||||
// scale must be 2 or 4
|
||||
//
|
||||
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
|
||||
// missing tokens are padded with 0 (visual token padding value).
|
||||
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
|
||||
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
|
||||
// Each token at (i, j) becomes scale*scale tokens in the output
|
||||
|
||||
mlx.Eval(tokens)
|
||||
tokenData := tokens.DataInt32()
|
||||
numTokens := int32(len(tokenData))
|
||||
expectedTokens := prevH * prevW
|
||||
|
||||
// Warn if we got fewer tokens than expected (early EOS)
|
||||
if numTokens < expectedTokens {
|
||||
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
|
||||
numTokens, expectedTokens)
|
||||
}
|
||||
|
||||
targetH := prevH * scale
|
||||
targetW := prevW * scale
|
||||
upsampled := make([]int32, targetH*targetW)
|
||||
|
||||
for i := int32(0); i < prevH; i++ {
|
||||
for j := int32(0); j < prevW; j++ {
|
||||
srcIdx := i*prevW + j
|
||||
|
||||
// Handle early EOS: use 0 (padding) for missing tokens
|
||||
var val int32
|
||||
if srcIdx < numTokens {
|
||||
val = tokenData[srcIdx]
|
||||
} else {
|
||||
val = 0 // Padding token
|
||||
}
|
||||
|
||||
// Place in scale*scale positions
|
||||
dstI := i * scale
|
||||
dstJ := j * scale
|
||||
for di := int32(0); di < scale; di++ {
|
||||
for dj := int32(0); dj < scale; dj++ {
|
||||
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
|
||||
}
|
||||
|
||||
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// Transpose: -> [B, pH, pW, C, p, p]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// Flatten: -> [B, pH*pW, C*p*p]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
|
||||
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// Transpose: -> [B, C, pH, p, pW, p]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// Reshape: -> [B, C, H, W]
|
||||
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
return m*cfg.MaxShift + cfg.BaseShift
|
||||
}
|
||||
|
||||
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
|
||||
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
|
||||
// This matches diffusers' _upsample_token_ids function
|
||||
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
|
||||
shape := tokens.Shape()
|
||||
B := shape[0]
|
||||
|
||||
// Reshape to [B, 1, H, W] for interpolation
|
||||
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
|
||||
|
||||
// Convert to float for interpolation
|
||||
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
|
||||
|
||||
// 2x nearest neighbor upsample
|
||||
// [B, 1, H, W] -> [B, 1, H*2, W*2]
|
||||
upsampled := nearestUpsample2x(tokensFloat)
|
||||
|
||||
// Convert back to int and reshape to [B, H*2*W*2]
|
||||
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
|
||||
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
|
||||
}
|
||||
|
||||
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
|
||||
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Repeat each element 2x2
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
|
||||
// Tile to repeat each pixel 2x2
|
||||
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
|
||||
|
||||
// Reshape to final size
|
||||
return mlx.Reshape(x, B, C, H*2, W*2)
|
||||
}
|
||||
358
x/imagegen/models/glm_image/glm_tokenizer.go
Normal file
@@ -0,0 +1,358 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// GLMTokenizer implements the GLM tokenizer for the AR model
|
||||
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
|
||||
// greedy longest-match tokenization from the vocab without runtime merging.
|
||||
type GLMTokenizer struct {
|
||||
Vocab map[string]int32 // token string -> token ID
|
||||
VocabReverse map[int32]string // token ID -> token string
|
||||
SpecialTokens map[string]int32 // special token strings -> IDs
|
||||
|
||||
// Special token IDs
|
||||
SopTokenID int32 // <sop> = grid_bos_token (167845)
|
||||
EopTokenID int32 // <eop> = grid_eos_token (167846)
|
||||
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
|
||||
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
|
||||
PadTokenID int32
|
||||
|
||||
// Sorted vocab keys by length (longest first) for greedy matching
|
||||
sortedTokens []string
|
||||
}
|
||||
|
||||
// tokenizerJSON represents the structure of tokenizer.json
|
||||
type tokenizerJSON struct {
|
||||
Model struct {
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
|
||||
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory in manifest
|
||||
data, err := manifest.ReadConfig("processor/tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
|
||||
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory
|
||||
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
|
||||
data, err := os.ReadFile(tokenizerPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// Encode tokenizes a string into token IDs
|
||||
// This uses greedy longest-match tokenization with GPT-2 style space handling
|
||||
func (t *GLMTokenizer) Encode(text string) []int32 {
|
||||
if text == "" {
|
||||
return []int32{}
|
||||
}
|
||||
|
||||
var tokens []int32
|
||||
|
||||
// First, check for and handle special tokens
|
||||
// Replace special tokens with placeholders, encode, then restore
|
||||
specialReplacements := make(map[string]int32)
|
||||
for special, id := range t.SpecialTokens {
|
||||
if strings.Contains(text, special) {
|
||||
specialReplacements[special] = id
|
||||
}
|
||||
}
|
||||
|
||||
// Process text character by character with special token handling
|
||||
i := 0
|
||||
isFirstToken := true
|
||||
|
||||
for i < len(text) {
|
||||
// Check for special tokens first
|
||||
foundSpecial := false
|
||||
for special, id := range specialReplacements {
|
||||
if strings.HasPrefix(text[i:], special) {
|
||||
tokens = append(tokens, id)
|
||||
i += len(special)
|
||||
isFirstToken = false
|
||||
foundSpecial = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular text with GPT-2 style space prefix
|
||||
// "Ġ" (U+0120) represents a space before a token
|
||||
remaining := text[i:]
|
||||
|
||||
// Try to find the longest matching token
|
||||
matched := false
|
||||
for _, token := range t.sortedTokens {
|
||||
// Skip special tokens in regular matching
|
||||
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this token matches
|
||||
tokenText := token
|
||||
|
||||
// Handle the Ġ prefix (represents space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
// This token expects a leading space
|
||||
if i > 0 || !isFirstToken {
|
||||
// Check if remaining starts with space + token content
|
||||
tokenContent := token[len("Ġ"):]
|
||||
if strings.HasPrefix(remaining, " "+tokenContent) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += 1 + len(tokenContent) // space + content
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular token without space prefix
|
||||
if strings.HasPrefix(remaining, tokenText) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += len(tokenText)
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
// No token found - skip this character (or use UNK)
|
||||
// For now, just skip unknown characters
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// EncodeForGeneration encodes a prompt with grid tokens for image generation
|
||||
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
//
|
||||
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
|
||||
// space prefix), matching the HuggingFace tokenizer behavior.
|
||||
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
|
||||
// Calculate grid dimensions
|
||||
factor := int32(32)
|
||||
height := (targetHeight / factor) * factor
|
||||
width := (targetWidth / factor) * factor
|
||||
tokenH := height / factor
|
||||
tokenW := width / factor
|
||||
|
||||
// Calculate previous grid dimensions
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(sqrt(ratio) * 16)
|
||||
prevTokenW := int32(sqrt(1.0/ratio) * 16)
|
||||
|
||||
// Encode the prompt text
|
||||
promptTokens := t.Encode(prompt)
|
||||
|
||||
// Build the full sequence:
|
||||
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
|
||||
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
|
||||
var tokens []int32
|
||||
tokens = append(tokens, promptTokens...)
|
||||
|
||||
// First grid: <sop> H W <eop>
|
||||
// First number has no space prefix, second number has space prefix (Ġ)
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(tokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// Second grid: <sop> prevH prevW <eop>
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// BOS token (start of image generation)
|
||||
tokens = append(tokens, t.BosTokenID)
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
|
||||
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
// First try: look up the whole number as a single token
|
||||
if id, ok := t.Vocab[s]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
// Fallback: encode digit by digit
|
||||
var tokens []int32
|
||||
for _, c := range s {
|
||||
if id, ok := t.Vocab[string(c)]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
|
||||
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
|
||||
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
|
||||
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
|
||||
spaceToken := "Ġ" + s
|
||||
if id, ok := t.Vocab[spaceToken]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
|
||||
// Fallback: bare space Ġ + number tokens
|
||||
var tokens []int32
|
||||
if spaceID, ok := t.Vocab["Ġ"]; ok {
|
||||
tokens = append(tokens, spaceID)
|
||||
}
|
||||
tokens = append(tokens, t.encodeNumber(n)...)
|
||||
return tokens
|
||||
}
|
||||
|
||||
// sqrt is a helper for float64 sqrt
|
||||
func sqrt(x float64) float64 {
|
||||
if x <= 0 {
|
||||
return 0
|
||||
}
|
||||
// Newton's method
|
||||
z := x
|
||||
for i := 0; i < 10; i++ {
|
||||
z = z - (z*z-x)/(2*z)
|
||||
}
|
||||
return z
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *GLMTokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if token, ok := t.VocabReverse[id]; ok {
|
||||
// Handle Ġ prefix (convert back to space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(token[len("Ġ"):])
|
||||
} else {
|
||||
sb.WriteString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
159
x/imagegen/models/glm_image/scheduler.go
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.25
|
||||
MaxShift float32 `json:"max_shift"` // 0.75
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
|
||||
TimeShiftType string `json:"time_shift_type"` // "linear"
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns the default config for GLM-Image
|
||||
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.25,
|
||||
MaxShift: 0.75,
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 4096,
|
||||
UseDynamicShifting: true,
|
||||
TimeShiftType: "linear",
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
|
||||
Sigmas []float32 // Shifted sigmas for Euler step calculation
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{Config: cfg}
|
||||
}
|
||||
|
||||
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
|
||||
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
|
||||
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate shift (mu) based on image sequence length
|
||||
mu := s.calculateShift(imgSeqLen)
|
||||
|
||||
// Create timesteps: linspace from sigma_max_t to sigma_min_t
|
||||
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
|
||||
// Then apply time shift and append terminal sigma=0
|
||||
s.Timesteps = make([]float32, numSteps)
|
||||
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
|
||||
|
||||
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
|
||||
|
||||
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
// linspace from 1000 to ~20 (sigma_min * num_train_timesteps)
|
||||
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
|
||||
s.Timesteps[i] = tRaw
|
||||
|
||||
// Convert to sigma [0, 1]
|
||||
sigma := tRaw / numTrainTimesteps
|
||||
|
||||
// Apply time shift if enabled
|
||||
if s.Config.UseDynamicShifting && mu > 0 {
|
||||
sigma = s.applyShift(mu, sigma)
|
||||
}
|
||||
|
||||
s.Sigmas[i] = sigma
|
||||
}
|
||||
|
||||
// Append terminal sigma = 0 (the final clean image)
|
||||
s.Sigmas[numSteps] = 0
|
||||
}
|
||||
|
||||
// calculateShift computes dynamic shift based on image sequence length
|
||||
// Uses the sqrt-based formula from diffusers:
|
||||
// m = (image_seq_len / base_seq_len) ** 0.5
|
||||
// mu = m * max_shift + base_shift
|
||||
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
|
||||
cfg := s.Config
|
||||
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
mu := m*cfg.MaxShift + cfg.BaseShift
|
||||
return mu
|
||||
}
|
||||
|
||||
// applyShift applies time shift transformation
|
||||
// mu: the computed shift value
|
||||
// t: sigma value in [0, 1]
|
||||
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
if t >= 1 {
|
||||
return 1
|
||||
}
|
||||
|
||||
// sigma=1.0 for both shift types
|
||||
sigma := float32(1.0)
|
||||
|
||||
if s.Config.TimeShiftType == "linear" {
|
||||
// Linear: mu / (mu + (1/t - 1)^sigma)
|
||||
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
|
||||
sigma := s.Sigmas[stepIdx]
|
||||
sigmaNext := s.Sigmas[stepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + dt * v_t
|
||||
dt := sigmaNext - sigma // Negative (going from noise to clean)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep (for img2img)
|
||||
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
|
||||
// Use sigmas (shifted) for the interpolation
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
oneMinusSigma := 1.0 - sigma
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
|
||||
scaledNoise := mlx.MulScalar(noise, sigma)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps
|
||||
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
497
x/imagegen/models/glm_image/text_encoder.go
Normal file
@@ -0,0 +1,497 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// T5Config holds T5 encoder configuration
|
||||
type T5Config struct {
|
||||
DModel int32 `json:"d_model"` // 1472
|
||||
DFF int32 `json:"d_ff"` // 3584
|
||||
DKV int32 `json:"d_kv"` // 64
|
||||
NumHeads int32 `json:"num_heads"` // 6
|
||||
NumLayers int32 `json:"num_layers"` // 12
|
||||
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
|
||||
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
|
||||
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
|
||||
|
||||
// Relative position bias
|
||||
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
|
||||
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
|
||||
}
|
||||
|
||||
// T5TextEncoder is the T5 encoder for text conditioning
|
||||
type T5TextEncoder struct {
|
||||
Config *T5Config
|
||||
|
||||
// Embedding (shared for ByT5)
|
||||
SharedEmbed *nn.Embedding `weight:"shared"`
|
||||
|
||||
// Encoder layers
|
||||
Layers []*T5Block `weight:"encoder.block"`
|
||||
|
||||
// Final layer norm
|
||||
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
|
||||
|
||||
// Relative position bias (from first layer, shared across all)
|
||||
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
|
||||
}
|
||||
|
||||
// T5Block is a single T5 encoder block
|
||||
type T5Block struct {
|
||||
// Self attention
|
||||
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
|
||||
// FFN
|
||||
Layer1 *T5LayerFF `weight:"layer.1"`
|
||||
}
|
||||
|
||||
// T5LayerSelfAttention is T5's self-attention layer
|
||||
type T5LayerSelfAttention struct {
|
||||
SelfAttention *T5Attention `weight:"SelfAttention"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5Attention implements T5's relative attention
|
||||
type T5Attention struct {
|
||||
Q *mlx.Array `weight:"q.weight"` // No bias in T5
|
||||
K *mlx.Array `weight:"k.weight"`
|
||||
V *mlx.Array `weight:"v.weight"`
|
||||
O *mlx.Array `weight:"o.weight"`
|
||||
|
||||
NHeads int32
|
||||
DKV int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// T5LayerFF is T5's feedforward layer with gated-gelu
|
||||
type T5LayerFF struct {
|
||||
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5DenseGatedGelu is T5's gated-gelu FFN
|
||||
type T5DenseGatedGelu struct {
|
||||
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
|
||||
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
|
||||
Wo *mlx.Array `weight:"wo.weight"` // down projection
|
||||
}
|
||||
|
||||
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
|
||||
type T5LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Load loads the T5 text encoder from manifest
|
||||
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the T5 text encoder from a directory path
|
||||
func (m *T5TextEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *T5TextEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
m.FinalNorm.Eps = cfg.LayerNormEps
|
||||
for _, block := range m.Layers {
|
||||
attn := block.Layer0.SelfAttention
|
||||
attn.NHeads = cfg.NumHeads
|
||||
attn.DKV = cfg.DKV
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
|
||||
|
||||
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
|
||||
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Get embeddings
|
||||
h := m.SharedEmbed.Forward(tokens)
|
||||
|
||||
// Compute relative position bias once
|
||||
seqLen := tokens.Shape()[1]
|
||||
posBias := m.computeRelativePositionBias(seqLen)
|
||||
|
||||
// Forward through layers
|
||||
for _, block := range m.Layers {
|
||||
h = block.Forward(h, posBias, cfg.LayerNormEps)
|
||||
}
|
||||
|
||||
// Final norm
|
||||
h = m.FinalNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
|
||||
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
|
||||
// Glyph texts are used for text rendering guidance in the generated image
|
||||
func extractGlyphTexts(prompt string) []string {
|
||||
var glyphTexts []string
|
||||
|
||||
// Extract text in single quotes: 'text'
|
||||
re1 := regexp.MustCompile(`'([^']*)'`)
|
||||
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Unicode curly double quotes: "text"
|
||||
re2 := regexp.MustCompile(`"([^""]*)"`)
|
||||
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in ASCII double quotes: "text"
|
||||
re3 := regexp.MustCompile(`"([^"]*)"`)
|
||||
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Japanese quotes: 「text」
|
||||
re4 := regexp.MustCompile(`「([^「」]*)」`)
|
||||
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
return glyphTexts
|
||||
}
|
||||
|
||||
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
|
||||
// This provides text conditioning for the diffusion transformer via the glyph projector
|
||||
//
|
||||
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
|
||||
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
|
||||
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
|
||||
// This matches diffusers' _get_glyph_embeds() behavior.
|
||||
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
|
||||
// Extract glyph texts from prompt (text in quotes)
|
||||
glyphTexts := extractGlyphTexts(prompt)
|
||||
|
||||
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
|
||||
if len(glyphTexts) == 0 {
|
||||
glyphTexts = []string{""}
|
||||
}
|
||||
|
||||
// Encode each glyph text and collect token sequences
|
||||
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
|
||||
var allTokenSeqs [][]int32
|
||||
|
||||
for _, glyphText := range glyphTexts {
|
||||
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
|
||||
tokens := tok.Encode(glyphText)
|
||||
|
||||
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
|
||||
tokens = append(tokens, tok.EOSTokenID)
|
||||
|
||||
allTokenSeqs = append(allTokenSeqs, tokens)
|
||||
}
|
||||
|
||||
// Process each glyph text through the encoder
|
||||
var allEmbeddings []*mlx.Array
|
||||
for _, tokens := range allTokenSeqs {
|
||||
tokenLen := len(tokens)
|
||||
if tokenLen == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create token array [1, L]
|
||||
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
|
||||
|
||||
// Forward through encoder
|
||||
output := m.Forward(tokensArr)
|
||||
mlx.Eval(output)
|
||||
|
||||
allEmbeddings = append(allEmbeddings, output)
|
||||
}
|
||||
|
||||
// Concatenate all glyph embeddings along sequence dimension
|
||||
var output *mlx.Array
|
||||
if len(allEmbeddings) == 0 {
|
||||
// Fallback: return single zero embedding
|
||||
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
|
||||
} else if len(allEmbeddings) == 1 {
|
||||
output = allEmbeddings[0]
|
||||
} else {
|
||||
output = mlx.Concatenate(allEmbeddings, 1)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// computeRelativePositionBias computes T5's relative position encoding
|
||||
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create relative position matrix
|
||||
// For each (query_pos, key_pos) pair, compute bucketed relative position
|
||||
numBuckets := cfg.RelativeAttentionNumBuckets
|
||||
maxDistance := cfg.RelativeAttentionMaxDistance
|
||||
|
||||
// Create position indices
|
||||
contextPos := make([]int32, seqLen*seqLen)
|
||||
memoryPos := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
for j := int32(0); j < seqLen; j++ {
|
||||
contextPos[i*seqLen+j] = i
|
||||
memoryPos[i*seqLen+j] = j
|
||||
}
|
||||
}
|
||||
|
||||
// Compute relative positions and bucket them
|
||||
buckets := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen*seqLen; i++ {
|
||||
relPos := memoryPos[i] - contextPos[i]
|
||||
buckets[i] = relativePosistionBucket(relPos, numBuckets, maxDistance, false)
|
||||
}
|
||||
|
||||
// Create bucket indices array
|
||||
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
|
||||
|
||||
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
|
||||
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
|
||||
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
|
||||
|
||||
// Transpose to [numHeads, seqLen, seqLen]
|
||||
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
|
||||
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
|
||||
|
||||
return bias
|
||||
}
|
||||
|
||||
// relativePosistionBucket computes the bucket for a relative position
|
||||
func relativePosistionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
|
||||
var bucket int32 = 0
|
||||
var n int32 = -relativePosition
|
||||
|
||||
if bidirectional {
|
||||
numBuckets /= 2
|
||||
if n < 0 {
|
||||
bucket += numBuckets
|
||||
n = -n
|
||||
}
|
||||
} else {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Half buckets are for exact positions, half are for log-spaced
|
||||
maxExact := numBuckets / 2
|
||||
if n < maxExact {
|
||||
bucket += n
|
||||
} else {
|
||||
// Log-spaced buckets
|
||||
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
|
||||
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
|
||||
if bucket > numBuckets-1 {
|
||||
bucket = numBuckets - 1
|
||||
}
|
||||
}
|
||||
|
||||
return bucket
|
||||
}
|
||||
|
||||
// Forward for T5Block
|
||||
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Self attention with residual
|
||||
h := b.Layer0.Forward(x, posBias, eps)
|
||||
|
||||
// FFN with residual
|
||||
h = b.Layer1.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward for T5LayerSelfAttention
|
||||
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// Attention
|
||||
attnOut := l.SelfAttention.Forward(normed, posBias)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, attnOut)
|
||||
}
|
||||
|
||||
// Forward for T5Attention
|
||||
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Q, K, V projections (no bias)
|
||||
// Weights are [out_features, in_features], so we use matmul with transpose
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
|
||||
|
||||
// Reshape to [B, L, nheads, d_kv]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
|
||||
|
||||
// Transpose to [B, nheads, L, d_kv]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Attention scores with relative position bias
|
||||
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
|
||||
// (no 1/sqrt(d_k) scale factor like in standard transformers)
|
||||
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
|
||||
scores = mlx.Add(scores, posBias)
|
||||
|
||||
// Softmax
|
||||
attnWeights := mlx.Softmax(scores, -1)
|
||||
|
||||
// Attend to values
|
||||
out := mlx.Matmul(attnWeights, v)
|
||||
|
||||
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, D]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
|
||||
|
||||
_ = D // Silence unused warning
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for T5LayerFF
|
||||
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// FFN
|
||||
ffOut := l.DenseReluDense.Forward(normed)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, ffOut)
|
||||
}
|
||||
|
||||
// geluNew implements the GELU activation with tanh approximation (gelu_new)
|
||||
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
|
||||
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
func geluNew(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
|
||||
coeff := float32(0.044715)
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// Forward for T5DenseGatedGelu (gated-gelu activation)
|
||||
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
|
||||
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
|
||||
gate = geluNew(gate)
|
||||
|
||||
// Up projection
|
||||
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
|
||||
|
||||
// Gated output
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
|
||||
}
|
||||
|
||||
// Forward for T5LayerNorm (RMSNorm variant)
|
||||
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
|
||||
variance := mlx.Mean(mlx.Square(x), -1, true)
|
||||
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
|
||||
return mlx.Mul(x, ln.Weight)
|
||||
}
|
||||
1255
x/imagegen/models/glm_image/transformer.go
Normal file
477
x/imagegen/models/glm_image/vae.go
Normal file
@@ -0,0 +1,477 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 16
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 3
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
|
||||
ShiftFactor *float32 `json:"shift_factor"` // null
|
||||
LatentsMean []float32 `json:"latents_mean"` // [16 values]
|
||||
LatentsStd []float32 `json:"latents_std"` // [16 values]
|
||||
}
|
||||
|
||||
// VAEDecoder is the VAE latent decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Decoder components
|
||||
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
|
||||
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
|
||||
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
|
||||
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
|
||||
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
|
||||
}
|
||||
|
||||
// VAEConv2d is a 2D convolution layer
|
||||
type VAEConv2d struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// GroupNorm is group normalization
|
||||
type GroupNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block of the VAE
|
||||
type VAEMidBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
}
|
||||
|
||||
// VAEUpBlock is an upsampling block
|
||||
type VAEUpBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
|
||||
}
|
||||
|
||||
// VAEResnetBlock is a residual block
|
||||
type VAEResnetBlock struct {
|
||||
Norm1 *GroupNorm `weight:"norm1"`
|
||||
Conv1 *VAEConv2d `weight:"conv1"`
|
||||
Norm2 *GroupNorm `weight:"norm2"`
|
||||
Conv2 *VAEConv2d `weight:"conv2"`
|
||||
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
|
||||
}
|
||||
|
||||
// VAEUpsampler is an upsampling layer
|
||||
type VAEUpsampler struct {
|
||||
Conv *VAEConv2d `weight:"conv"`
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from manifest
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
|
||||
// And all but the last up_block has an upsampler
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
// All but the last block has upsamplers
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the VAE decoder from a directory path
|
||||
func (m *VAEDecoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VAEDecoder) initGroupNorms() {
|
||||
cfg := m.Config
|
||||
numGroups := cfg.NormNumGroups
|
||||
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
|
||||
|
||||
if m.ConvNormOut != nil {
|
||||
m.ConvNormOut.NumGroups = numGroups
|
||||
m.ConvNormOut.Eps = eps
|
||||
}
|
||||
|
||||
if m.MidBlock != nil {
|
||||
for _, resnet := range m.MidBlock.Resnets {
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, upBlock := range m.UpBlocks {
|
||||
if upBlock == nil {
|
||||
continue
|
||||
}
|
||||
for _, resnet := range upBlock.Resnets {
|
||||
if resnet == nil {
|
||||
continue
|
||||
}
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes latents to an image
|
||||
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Apply latent denormalization if mean/std are provided
|
||||
// This matches diffusers GLM-Image: latents = latents * std + mean
|
||||
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
|
||||
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
|
||||
latents = m.denormalizeLatents(latents)
|
||||
}
|
||||
|
||||
// Convert from NCHW to NHWC for processing
|
||||
// [B, C, H, W] -> [B, H, W, C]
|
||||
x := mlx.Transpose(latents, 0, 2, 3, 1)
|
||||
|
||||
// Initial convolution
|
||||
x = m.ConvIn.Forward(x)
|
||||
|
||||
// Mid block
|
||||
x = m.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
|
||||
for i := 0; i < len(m.UpBlocks); i++ {
|
||||
if m.UpBlocks[i] != nil {
|
||||
x = m.UpBlocks[i].Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Final normalization and convolution
|
||||
x = m.ConvNormOut.Forward(x)
|
||||
x = mlx.SiLU(x)
|
||||
x = m.ConvOut.Forward(x)
|
||||
|
||||
// Convert back to NCHW
|
||||
// [B, H, W, C] -> [B, C, H, W]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 2)
|
||||
|
||||
// Clamp to valid range and convert to [0, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.AddScalar(x, 1.0)
|
||||
x = mlx.DivScalar(x, 2.0)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// denormalizeLatents applies the latent mean/std denormalization
|
||||
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create mean and std arrays [1, C, 1, 1] for broadcasting
|
||||
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
|
||||
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
|
||||
|
||||
// Denormalize: latents * std + mean
|
||||
latents = mlx.Mul(latents, std)
|
||||
latents = mlx.Add(latents, mean)
|
||||
|
||||
return latents
|
||||
}
|
||||
|
||||
// Forward for VAEConv2d
|
||||
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C_in] (NHWC)
|
||||
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
|
||||
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
|
||||
// So we need to transpose from OIHW to OHWI
|
||||
|
||||
stride := c.Stride
|
||||
if stride == 0 {
|
||||
stride = 1
|
||||
}
|
||||
padding := c.Padding
|
||||
if padding == 0 {
|
||||
// Default to same padding for 3x3 kernels
|
||||
wShape := c.Weight.Shape()
|
||||
if len(wShape) >= 3 && wShape[2] == 3 {
|
||||
padding = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
|
||||
|
||||
out := mlx.Conv2d(x, weight, stride, padding)
|
||||
if c.Bias != nil {
|
||||
// Bias: [C_out] -> [1, 1, 1, C_out]
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for GroupNorm
|
||||
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C] (NHWC)
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
numGroups := gn.NumGroups
|
||||
if numGroups == 0 {
|
||||
numGroups = 32
|
||||
}
|
||||
groupSize := C / numGroups
|
||||
|
||||
// Reshape to [B, H, W, groups, groupSize]
|
||||
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Forward for VAEMidBlock
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range mb.Resnets {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEUpBlock
|
||||
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Apply resnets
|
||||
for _, resnet := range ub.Resnets {
|
||||
if resnet != nil {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply upsamplers
|
||||
for _, upsampler := range ub.Upsamplers {
|
||||
if upsampler != nil {
|
||||
x = upsampler.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEResnetBlock
|
||||
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
|
||||
// First norm + activation + conv
|
||||
h := rb.Norm1.Forward(x)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
|
||||
// Second norm + activation + conv
|
||||
h = rb.Norm2.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
|
||||
// Shortcut for channel mismatch
|
||||
if rb.ConvShortcut != nil {
|
||||
residual = rb.ConvShortcut.Forward(residual)
|
||||
}
|
||||
|
||||
return mlx.Add(h, residual)
|
||||
}
|
||||
|
||||
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
|
||||
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C]
|
||||
// 2x nearest neighbor upsample
|
||||
x = upsample2x(x)
|
||||
|
||||
// Conv
|
||||
if us.Conv != nil {
|
||||
x = us.Conv.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2x performs 2x nearest neighbor upsampling.
|
||||
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
|
||||
hIndices := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
hIndices[i*2] = i
|
||||
hIndices[i*2+1] = i
|
||||
}
|
||||
wIndices := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
wIndices[i*2] = i
|
||||
wIndices[i*2+1] = i
|
||||
}
|
||||
|
||||
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
|
||||
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
|
||||
|
||||
// Take along height axis
|
||||
x = mlx.Reshape(x, B*H, W, C)
|
||||
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
|
||||
x = mlx.Reshape(x, B, H, W*2, C)
|
||||
|
||||
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
|
||||
x = mlx.Reshape(x, B*(W*2), H, C)
|
||||
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
|
||||
x = mlx.Reshape(x, B, W*2, H*2, C)
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
|
||||
|
||||
return x
|
||||
}
|
||||
982
x/imagegen/models/glm_image/vision_language_encoder.go
Normal file
@@ -0,0 +1,982 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VisionLanguageConfig holds GLM-Image AR generator configuration
|
||||
type VisionLanguageConfig struct {
|
||||
// Text model config
|
||||
HiddenSize int32 `json:"hidden_size"` // 4096
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
|
||||
IntermediateSize int32 `json:"intermediate_size"` // 13696
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
|
||||
VocabSize int32 `json:"vocab_size"` // 168064
|
||||
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
|
||||
|
||||
// RoPE config
|
||||
RopeTheta float32 `json:"rope_theta"` // 10000
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
|
||||
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
|
||||
|
||||
// Visual token config
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
|
||||
ImageTokenID int32 `json:"image_token_id"` // 167855
|
||||
|
||||
// Computed
|
||||
HeadDim int32
|
||||
}
|
||||
|
||||
// VisionLanguageEncoder is the 9B AR generator
|
||||
type VisionLanguageEncoder struct {
|
||||
Config *VisionLanguageConfig
|
||||
|
||||
// Embedding
|
||||
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
|
||||
|
||||
// Transformer layers
|
||||
Layers []*GLMBlock `weight:"model.language_model.layers"`
|
||||
|
||||
// Final norm
|
||||
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
|
||||
|
||||
// LM Head
|
||||
LMHead *mlx.Array `weight:"lm_head.weight"`
|
||||
}
|
||||
|
||||
// GLMBlock is a single transformer block in GLM-4 style
|
||||
type GLMBlock struct {
|
||||
// Pre-attention norm (GLM uses post-LN variant)
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
|
||||
|
||||
// Attention
|
||||
SelfAttn *GLMAttention `weight:"self_attn"`
|
||||
|
||||
// MLP (fused gate_up)
|
||||
MLP *GLMMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// GLMAttention implements GQA with partial rotary and MRoPE
|
||||
type GLMAttention struct {
|
||||
QProj *mlx.Array `weight:"q_proj.weight"`
|
||||
KProj *mlx.Array `weight:"k_proj.weight"`
|
||||
VProj *mlx.Array `weight:"v_proj.weight"`
|
||||
OProj *mlx.Array `weight:"o_proj.weight"`
|
||||
|
||||
// QKV have biases in GLM
|
||||
QBias *mlx.Array `weight:"q_proj.bias"`
|
||||
KBias *mlx.Array `weight:"k_proj.bias"`
|
||||
VBias *mlx.Array `weight:"v_proj.bias"`
|
||||
|
||||
// Computed
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
PartialRotary float32 // Only rotate this fraction of head_dim
|
||||
RopeTheta float32
|
||||
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
|
||||
}
|
||||
|
||||
// ARCache holds KV caches for all layers using the shared cache implementation
|
||||
type ARCache struct {
|
||||
Layers []cache.Cache
|
||||
}
|
||||
|
||||
// NewARCache creates a new cache for the given number of layers
|
||||
func NewARCache(numLayers int32) *ARCache {
|
||||
layers := make([]cache.Cache, numLayers)
|
||||
for i := range layers {
|
||||
layers[i] = cache.NewKVCache()
|
||||
}
|
||||
return &ARCache{Layers: layers}
|
||||
}
|
||||
|
||||
// Free releases all cached tensors
|
||||
func (c *ARCache) Free() {
|
||||
for _, layer := range c.Layers {
|
||||
for _, arr := range layer.State() {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GLMMLP implements fused gate_up SwiGLU MLP
|
||||
type GLMMLP struct {
|
||||
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
|
||||
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
|
||||
DownProj *mlx.Array `weight:"down_proj.weight"`
|
||||
}
|
||||
|
||||
// Load loads the vision-language encoder from manifest
|
||||
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the vision-language encoder from a directory path
|
||||
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &rawCfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VisionLanguageEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
for _, block := range m.Layers {
|
||||
block.SelfAttn.NHeads = cfg.NumAttentionHeads
|
||||
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.SelfAttn.HeadDim = cfg.HeadDim
|
||||
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
|
||||
block.SelfAttn.RopeTheta = cfg.RopeTheta
|
||||
block.SelfAttn.MRoPESection = cfg.MRoPESection
|
||||
|
||||
// Set norm eps
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
// Generate autoregressively generates visual tokens with KV caching
|
||||
func (m *VisionLanguageEncoder) Generate(
|
||||
prompt string,
|
||||
tok *GLMTokenizer,
|
||||
maxTokens int32,
|
||||
temperature float32,
|
||||
topP float32,
|
||||
seed int64,
|
||||
targetHeight, targetWidth int32,
|
||||
progressFn func(int),
|
||||
) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Encode prompt with grid tokens using GLM tokenizer
|
||||
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
|
||||
|
||||
// Calculate grid dimensions for MRoPE position IDs
|
||||
factor := int32(32)
|
||||
tokenH := targetHeight / factor
|
||||
tokenW := targetWidth / factor
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
|
||||
prevGridSize := prevTokenH * prevTokenW
|
||||
|
||||
// Create KV cache for all layers
|
||||
cache := NewARCache(cfg.NumHiddenLayers)
|
||||
defer cache.Free()
|
||||
|
||||
// ===== PREFILL PHASE =====
|
||||
// Process entire prompt at once, populate cache
|
||||
promptLen := int32(len(tokens))
|
||||
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
mlx.Eval(h)
|
||||
|
||||
// Compute position IDs for prefill (text tokens use same position for all dims)
|
||||
prefillPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
prefillPositions[dim] = make([]int32, promptLen)
|
||||
for i := int32(0); i < promptLen; i++ {
|
||||
prefillPositions[dim][i] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Forward through layers (prefill)
|
||||
for i, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
|
||||
if i > 0 {
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
// Eval h and cache arrays together so cache is materialized
|
||||
evalArgs := []*mlx.Array{h}
|
||||
for _, lc := range cache.Layers {
|
||||
evalArgs = append(evalArgs, lc.State()...)
|
||||
}
|
||||
mlx.Eval(evalArgs...)
|
||||
|
||||
// Final norm and get logits for last position
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
|
||||
h.Free()
|
||||
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
|
||||
lastH.Free()
|
||||
|
||||
// Sample first token
|
||||
var sampleCounter int64 = 0
|
||||
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
// AR generation loop with caching
|
||||
// Visual tokens are stored as VQ codebook indices [0, 16383]
|
||||
// The LM head outputs indices [0, 16511] where:
|
||||
// - [0, 16383] are VQ codes
|
||||
// - 16384 is BOS
|
||||
// - 16385 is EOS
|
||||
visualTokens := make([]int32, 0, maxTokens)
|
||||
posOffset := promptLen
|
||||
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
|
||||
|
||||
// Preallocate slice for old cache state to reuse
|
||||
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
|
||||
for i := int32(0); i < maxTokens; i++ {
|
||||
if progressFn != nil {
|
||||
progressFn(int(i))
|
||||
}
|
||||
|
||||
// Check for end token (EOS = 16385)
|
||||
if nextToken == cfg.ImageEndTokenID {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
|
||||
if nextToken == cfg.ImageStartTokenID {
|
||||
// BOS token - skip storing but continue generation
|
||||
} else if nextToken < cfg.ImageStartTokenID {
|
||||
// This is an actual VQ code [0, 16383] - store it
|
||||
visualTokens = append(visualTokens, nextToken)
|
||||
}
|
||||
// Tokens >= 16386 are other special tokens, skip them
|
||||
|
||||
// ===== DECODE PHASE =====
|
||||
// Save old cache state before forward (to free after eval)
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, lc := range cache.Layers {
|
||||
oldCacheState = append(oldCacheState, lc.State()...)
|
||||
}
|
||||
|
||||
// Only process the new token, use cached K,V
|
||||
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
// Compute MRoPE position IDs for this visual token
|
||||
// Visual tokens are arranged in two grids: prev grid then target grid
|
||||
// Position dimensions: [temporal, height, width]
|
||||
decodePositions := computeVisualTokenPositions(
|
||||
visualTokenIdx, posOffset, promptLen,
|
||||
prevTokenH, prevTokenW, prevGridSize,
|
||||
tokenH, tokenW,
|
||||
)
|
||||
|
||||
// Forward through layers (decode with cache)
|
||||
for j, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
|
||||
if j > 0 { // Don't free the embedding on first layer
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Eval h and new cache state
|
||||
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
for _, lc := range cache.Layers {
|
||||
newCacheState = append(newCacheState, lc.State()...)
|
||||
}
|
||||
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
|
||||
|
||||
// Free old cache state (now that new state is evaluated)
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
// Get logits (h is already [1, 1, hidden_size])
|
||||
h = mlx.Reshape(h, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
|
||||
h.Free()
|
||||
|
||||
// Sample next token
|
||||
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
posOffset++
|
||||
visualTokenIdx++
|
||||
|
||||
// Periodically clear cache to release intermediate memory
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
if len(visualTokens) == 0 {
|
||||
// Return at least one token to avoid empty tensor issues
|
||||
visualTokens = append(visualTokens, 0)
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
|
||||
}
|
||||
|
||||
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
|
||||
// Returns [3][1] position IDs for temporal, height, and width dimensions
|
||||
//
|
||||
// MRoPE position encoding for GLM-Image visual tokens:
|
||||
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
|
||||
// - height: decode_pos + row index within grid
|
||||
// - width: decode_pos + column index within grid
|
||||
//
|
||||
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
|
||||
// sufficient positional separation.
|
||||
func computeVisualTokenPositions(
|
||||
visualIdx int32, absPos int32, promptLen int32,
|
||||
prevH, prevW, prevSize int32,
|
||||
targetH, targetW int32,
|
||||
) [][]int32 {
|
||||
positions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
positions[dim] = make([]int32, 1)
|
||||
}
|
||||
|
||||
// First grid (prev grid) starts at decode_pos = promptLen
|
||||
prevGridDecodePos := promptLen
|
||||
|
||||
// Second grid (target grid) starts after first grid
|
||||
// next_pos = prev_decode_pos + max(prevH, prevW)
|
||||
maxPrev := prevH
|
||||
if prevW > maxPrev {
|
||||
maxPrev = prevW
|
||||
}
|
||||
targetGridDecodePos := prevGridDecodePos + maxPrev
|
||||
|
||||
// Compute position IDs based on which grid the token is in
|
||||
if visualIdx < prevSize {
|
||||
// Token is in the prev grid (prev_token_h × prev_token_w)
|
||||
row := visualIdx / prevW
|
||||
col := visualIdx % prevW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = prevGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = prevGridDecodePos + row
|
||||
positions[2][0] = prevGridDecodePos + col
|
||||
} else {
|
||||
// Token is in the target grid (token_h × token_w)
|
||||
targetIdx := visualIdx - prevSize
|
||||
row := targetIdx / targetW
|
||||
col := targetIdx % targetW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = targetGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = targetGridDecodePos + row
|
||||
positions[2][0] = targetGridDecodePos + col
|
||||
}
|
||||
|
||||
_ = targetH // Used for documentation clarity
|
||||
_ = absPos // No longer used - kept for API compatibility
|
||||
return positions
|
||||
}
|
||||
|
||||
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
|
||||
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
|
||||
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
|
||||
// sampleCounter is incremented for each call to ensure different random values
|
||||
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
|
||||
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
|
||||
// Output index directly corresponds to vocab ID [0, 16511]
|
||||
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
|
||||
visualLogits := logits
|
||||
|
||||
// Apply temperature
|
||||
if temperature != 1.0 && temperature > 0 {
|
||||
visualLogits = mlx.DivScalar(visualLogits, temperature)
|
||||
}
|
||||
|
||||
// Apply softmax to get probabilities
|
||||
probs := mlx.Softmax(visualLogits, -1)
|
||||
mlx.Eval(probs)
|
||||
|
||||
// Get the sampled index using top-p sampling
|
||||
// This directly gives us the vocab ID in [0, 16511]
|
||||
// Special tokens: 16384 = BOS, 16385 = EOS
|
||||
// Use seed + counter for reproducible but different random values
|
||||
effectiveSeed := seed + *sampleCounter
|
||||
*sampleCounter++
|
||||
return sampleTopP(probs, topP, effectiveSeed)
|
||||
}
|
||||
|
||||
// sampleTopP implements nucleus (top-p) sampling
|
||||
// probs: [1, vocab_size] probability distribution
|
||||
// topP: cumulative probability threshold (e.g., 0.75)
|
||||
// seed: random seed for reproducible sampling
|
||||
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
|
||||
// Negate probs for descending sort (Argsort only does ascending)
|
||||
negProbs := mlx.MulScalar(probs, -1)
|
||||
sortedIndices := mlx.Argsort(negProbs, -1)
|
||||
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
|
||||
cumProbs := mlx.Cumsum(sortedProbs, -1)
|
||||
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
|
||||
|
||||
// Find cutoff index where cumulative probability exceeds topP
|
||||
probsData := sortedProbs.Data()
|
||||
cumProbsData := cumProbs.Data()
|
||||
indicesData := sortedIndices.DataInt32()
|
||||
|
||||
// Calculate cutoff and renormalize
|
||||
var cutoffIdx int
|
||||
var totalProb float32
|
||||
for i, cp := range cumProbsData {
|
||||
totalProb += probsData[i]
|
||||
if cp >= topP {
|
||||
cutoffIdx = i + 1 // Include this token
|
||||
break
|
||||
}
|
||||
}
|
||||
if cutoffIdx == 0 {
|
||||
cutoffIdx = len(probsData) // Use all tokens if topP is very high
|
||||
}
|
||||
|
||||
// Sample from the truncated distribution
|
||||
// Renormalize the truncated probabilities
|
||||
truncatedProbs := make([]float32, cutoffIdx)
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
truncatedProbs[i] = probsData[i] / totalProb
|
||||
}
|
||||
|
||||
// Sample using random number with provided seed for reproducibility
|
||||
r := mlx.RandomUniform([]int32{1}, uint64(seed))
|
||||
mlx.Eval(r)
|
||||
randVal := r.Data()[0]
|
||||
|
||||
// Find the sampled token
|
||||
var cumulative float32
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
cumulative += truncatedProbs[i]
|
||||
if randVal < cumulative {
|
||||
return indicesData[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the last token in truncated set
|
||||
return indicesData[cutoffIdx-1]
|
||||
}
|
||||
|
||||
// Forward for GLMBlock
|
||||
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
|
||||
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs block forward with optional KV caching and MRoPE
|
||||
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
|
||||
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
// Pre-attention norm
|
||||
normed := b.InputLayerNorm.Forward(x, eps)
|
||||
|
||||
// Self-attention with RoPE/MRoPE and cache
|
||||
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
|
||||
|
||||
// Post-attention norm (GLM-4 style)
|
||||
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
// Post-attention layer norm
|
||||
normed = b.PostAttnLayerNorm.Forward(x, eps)
|
||||
|
||||
// MLP
|
||||
mlpOut := b.MLP.Forward(normed)
|
||||
|
||||
// Post-MLP norm
|
||||
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for GLMAttention (without cache - used for prefill)
|
||||
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
|
||||
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs attention with optional KV caching and MRoPE
|
||||
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
|
||||
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
|
||||
// kvcache is updated in-place if provided
|
||||
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// Q, K, V projections
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
|
||||
|
||||
// Add biases
|
||||
if attn.QBias != nil {
|
||||
q = mlx.Add(q, attn.QBias)
|
||||
}
|
||||
if attn.KBias != nil {
|
||||
k = mlx.Add(k, attn.KBias)
|
||||
}
|
||||
if attn.VBias != nil {
|
||||
v = mlx.Add(v, attn.VBias)
|
||||
}
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// Apply partial RoPE or MRoPE
|
||||
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
|
||||
if len(attn.MRoPESection) == 3 && positionIDs != nil {
|
||||
// Use MRoPE with explicit position IDs
|
||||
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else if len(attn.MRoPESection) == 3 {
|
||||
// Use MRoPE with sequential positions (same for all dims - text mode)
|
||||
seqPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
seqPositions[dim] = make([]int32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
seqPositions[dim][i] = i + posOffset
|
||||
}
|
||||
}
|
||||
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else {
|
||||
// Fallback to standard RoPE
|
||||
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Update cache and get full K, V for attention
|
||||
if kvcache != nil {
|
||||
k, v = kvcache.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
// Repeat KV for GQA
|
||||
kExpanded := k
|
||||
vExpanded := v
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
kExpanded = repeatKV(k, repeats)
|
||||
vExpanded = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention with causal mask
|
||||
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
|
||||
|
||||
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, hidden_size]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
|
||||
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
|
||||
}
|
||||
|
||||
// applyPartialRoPEWithOffset applies RoPE with a position offset
|
||||
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply RoPE to rotary part with position offset
|
||||
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
|
||||
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
|
||||
// mrope_section: [8, 12, 12] - frequency pairs per dimension
|
||||
// For text tokens: all 3 dimensions have the same sequential position
|
||||
// For image tokens: temporal=seq_idx, height=row, width=col
|
||||
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply MRoPE to rotary part
|
||||
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyMRoPE applies multi-dimensional rotary position embedding
|
||||
// x: [B, L, H, D] where D is the rotary dimension
|
||||
// positionIDs: [3][L] - positions for temporal, height, width dimensions
|
||||
// mropeSection: [8, 12, 12] - frequency pairs per dimension
|
||||
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Validate mrope_section sums to half (number of frequency pairs)
|
||||
var totalPairs int32
|
||||
for _, s := range mropeSection {
|
||||
totalPairs += s
|
||||
}
|
||||
if totalPairs != half {
|
||||
// Fallback to standard RoPE if section doesn't match
|
||||
return applyRoPEWithOffset(x, L, 0, theta)
|
||||
}
|
||||
|
||||
// Build angles for each position dimension (matching Python's MRoPE approach)
|
||||
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
|
||||
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
|
||||
angleVals := make([]*mlx.Array, 3)
|
||||
|
||||
freqOffset := int32(0)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
numPairs := mropeSection[dim]
|
||||
if numPairs == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compute inverse frequencies for this section
|
||||
// Each dimension uses DIFFERENT frequency ranges:
|
||||
// - Temporal: frequencies 0 to section[0]-1
|
||||
// - Height: frequencies section[0] to section[0]+section[1]-1
|
||||
// - Width: frequencies section[0]+section[1] to sum(section)-1
|
||||
freqsArr := make([]float32, numPairs)
|
||||
for i := int32(0); i < numPairs; i++ {
|
||||
globalIdx := freqOffset + i
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
|
||||
|
||||
// Position indices for this dimension
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(positionIDs[dim][i])
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, numPairs] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
|
||||
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
freqOffset += numPairs
|
||||
}
|
||||
|
||||
// Concatenate all sections: [L, half] = [L, 32]
|
||||
allAngles := mlx.Concatenate(angleVals, 1)
|
||||
|
||||
// Duplicate AFTER concatenation: [L, D] = [L, 64]
|
||||
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
|
||||
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
|
||||
|
||||
// Compute cos/sin
|
||||
allCos := mlx.Cos(allAngles)
|
||||
allSin := mlx.Sin(allAngles)
|
||||
|
||||
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
|
||||
allCos = mlx.Reshape(allCos, 1, L, 1, D)
|
||||
allSin = mlx.Reshape(allSin, 1, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1)
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary position embedding
|
||||
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
return applyRoPEWithOffset(x, seqLen, 0, theta)
|
||||
}
|
||||
|
||||
// applyRoPEWithOffset applies rotary position embedding with position offset
|
||||
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
|
||||
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Compute inverse frequencies: 1 / (theta^(2i/d))
|
||||
freqsArr := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
// Position indices with offset
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(i + posOffset)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, half] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
angles := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
|
||||
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
|
||||
|
||||
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
|
||||
cosVals := mlx.Cos(anglesDup)
|
||||
sinVals := mlx.Sin(anglesDup)
|
||||
cosVals = mlx.Reshape(cosVals, L, 1, D)
|
||||
sinVals = mlx.Reshape(sinVals, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// x: [B, nkvheads, L, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// x: [B, nkvheads, 1, L, head_dim]
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
// x: [B, nkvheads, repeats, L, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Forward for GLMMLP (fused gate_up SwiGLU)
|
||||
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
// gate_up_proj outputs [gate, up] concatenated
|
||||
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
|
||||
|
||||
shape := gateUp.Shape()
|
||||
halfDim := shape[len(shape)-1] / 2
|
||||
|
||||
// Split into gate and up
|
||||
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
|
||||
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
|
||||
|
||||
// SwiGLU: silu(gate) * up
|
||||
gate = mlx.SiLU(gate)
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
|
||||
}
|
||||
@@ -1,390 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Config holds Qwen3 text encoder configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// Attention implements Qwen3 attention with QK norms
|
||||
type Attention struct {
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking and optional padding mask
|
||||
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// MLP implements Qwen3 SwiGLU MLP
|
||||
type MLP struct {
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Block represents a single Qwen3 transformer block
|
||||
type Block struct {
|
||||
Attention *Attention `weight:"self_attn"`
|
||||
MLP *MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h, mask, maskMode)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// TextEncoder is the full Qwen3 encoder
|
||||
type TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Config
|
||||
if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
m.Layers = make([]*Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *TextEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
|
||||
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps, attnMask, maskMode)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
|
||||
// This is used by Flux2 which needs embeddings from specific intermediate layers.
|
||||
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
outputs := make([]*mlx.Array, len(layerIndices))
|
||||
layerSet := make(map[int]int)
|
||||
for i, idx := range layerIndices {
|
||||
layerSet[idx] = i
|
||||
}
|
||||
|
||||
for i, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps, attnMask, maskMode)
|
||||
if outIdx, ok := layerSet[i]; ok {
|
||||
outputs[outIdx] = h
|
||||
}
|
||||
}
|
||||
|
||||
return outputs
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in Qwen3 chat format.
|
||||
// If think is true, adds the <think></think> block after the assistant tag
|
||||
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
|
||||
func ApplyChatTemplate(prompt string, think bool) string {
|
||||
base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
|
||||
if think {
|
||||
return base + "<think>\n\n</think>\n\n"
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
|
||||
// If think is true, includes the <think></think> block in the chat template.
|
||||
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt, think)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
// Build combined causal + PAD mask [L, L]
|
||||
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
|
||||
L := int32(maxLen)
|
||||
validLen := int32(len(tokens))
|
||||
combinedMaskData := make([]float32, L*L)
|
||||
negInf := float32(-1e9)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
idx := i*L + j
|
||||
if j <= i && j < validLen {
|
||||
combinedMaskData[idx] = 0
|
||||
} else {
|
||||
combinedMaskData[idx] = negInf
|
||||
}
|
||||
}
|
||||
}
|
||||
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
|
||||
|
||||
embeddings := te.Forward(tokensArr, maskMat, "")
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
|
||||
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
|
||||
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
|
||||
// If think is true, includes the <think></think> block in the chat template.
|
||||
// Returns embeddings and padded sequence length.
|
||||
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt, think)
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
// Pad to maxLen
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
padded := make([]int32, maxLen)
|
||||
copy(padded, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
padded[i] = padToken
|
||||
}
|
||||
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
|
||||
|
||||
// Build combined causal + PAD mask [L, L]
|
||||
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
|
||||
// This combines causal masking with PAD token masking
|
||||
L := int32(maxLen)
|
||||
validLen := int32(len(tokens))
|
||||
maskData := make([]float32, L*L)
|
||||
negInf := float32(-1e9)
|
||||
for i := int32(0); i < L; i++ {
|
||||
for j := int32(0); j < L; j++ {
|
||||
idx := i*L + j
|
||||
if j <= i && j < validLen {
|
||||
maskData[idx] = 0 // allowed: causal OK and not PAD
|
||||
} else {
|
||||
maskData[idx] = negInf // blocked: future or PAD
|
||||
}
|
||||
}
|
||||
}
|
||||
maskMat := mlx.NewArray(maskData, []int32{L, L})
|
||||
|
||||
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
|
||||
|
||||
// Concatenate layer outputs along the hidden dimension
|
||||
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
|
||||
embeddings := mlx.Concatenate(layerOutputs, 2)
|
||||
|
||||
// Return embeddings and padded length
|
||||
return embeddings, int32(maxLen)
|
||||
}
|
||||
@@ -3,33 +3,12 @@
|
||||
package qwen_image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestPipelineOutput runs the full pipeline (integration test).
|
||||
// Skips if model weights not found. Requires ~50GB VRAM.
|
||||
func TestPipelineOutput(t *testing.T) {
|
||||
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 30)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// Layer caching (DeepCache/Learning-to-Cache speedup)
|
||||
LayerCache bool // Enable layer caching (default: false)
|
||||
@@ -31,6 +31,9 @@ type GenerateConfig struct {
|
||||
CacheLayers int // Number of shallow layers to cache (default: 25)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
@@ -114,7 +117,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
@@ -126,7 +129,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(&GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
@@ -169,7 +172,7 @@ func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
cfg.Steps = 30
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
|
||||
@@ -18,15 +18,18 @@ import (
|
||||
// GenerateConfig holds all options for image editing.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
|
||||
CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
|
||||
Width int32 // Output width (default: from input image)
|
||||
Height int32 // Output height (default: from input image)
|
||||
Steps int // Denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Qwen-Image-Edit diffusion model.
|
||||
type Model struct {
|
||||
ModelPath string
|
||||
|
||||
@@ -3,35 +3,13 @@
|
||||
package qwen_image_edit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestComputeAxisFreqs verifies frequency computation matches Python reference
|
||||
func TestComputeAxisFreqs(t *testing.T) {
|
||||
theta := float64(10000)
|
||||
|
||||
@@ -3,17 +3,287 @@
|
||||
package zimage
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen3"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Re-export types from shared qwen3 package for backwards compatibility
|
||||
type (
|
||||
Qwen3Config = qwen3.Config
|
||||
Qwen3Attention = qwen3.Attention
|
||||
Qwen3MLP = qwen3.MLP
|
||||
Qwen3Block = qwen3.Block
|
||||
Qwen3TextEncoder = qwen3.TextEncoder
|
||||
)
|
||||
// Qwen3Config holds Qwen3 text encoder configuration
|
||||
type Qwen3Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
}
|
||||
|
||||
// Qwen3Attention implements Qwen3 attention with QK norms
|
||||
type Qwen3Attention struct {
|
||||
QProj nn.LinearLayer `weight:"q_proj"`
|
||||
KProj nn.LinearLayer `weight:"k_proj"`
|
||||
VProj nn.LinearLayer `weight:"v_proj"`
|
||||
OProj nn.LinearLayer `weight:"o_proj"`
|
||||
QNorm *nn.RMSNorm `weight:"q_norm"`
|
||||
KNorm *nn.RMSNorm `weight:"k_norm"`
|
||||
// Computed fields
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
RopeTheta float32
|
||||
}
|
||||
|
||||
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
|
||||
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
freqsArr := make([]float32, half)
|
||||
logTheta := float32(math.Log(float64(theta)))
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
posArr := make([]float32, seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
posArr[i] = float32(i)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{seqLen})
|
||||
|
||||
posExpanded := mlx.Reshape(pos, seqLen, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
args := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
cosVals := mlx.Cos(args)
|
||||
sinVals := mlx.Sin(args)
|
||||
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
|
||||
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
|
||||
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
|
||||
|
||||
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
|
||||
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
|
||||
|
||||
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
|
||||
}
|
||||
|
||||
// Forward computes attention with causal masking
|
||||
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
q := attn.QProj.Forward(x)
|
||||
k := attn.KProj.Forward(x)
|
||||
v := attn.VProj.Forward(x)
|
||||
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
|
||||
q = attn.QNorm.Forward(q, 1e-6)
|
||||
k = attn.KNorm.Forward(k, 1e-6)
|
||||
|
||||
q = applyRoPEQwen3(q, L, attn.RopeTheta)
|
||||
k = applyRoPEQwen3(k, L, attn.RopeTheta)
|
||||
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
k = repeatKV(k, repeats)
|
||||
v = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
|
||||
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
out = attn.OProj.Forward(out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Qwen3MLP implements Qwen3 SwiGLU MLP
|
||||
type Qwen3MLP struct {
|
||||
GateProj nn.LinearLayer `weight:"gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the MLP
|
||||
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := m.GateProj.Forward(x)
|
||||
gate = mlx.SiLU(gate)
|
||||
up := m.UpProj.Forward(x)
|
||||
h := mlx.Mul(gate, up)
|
||||
return m.DownProj.Forward(h)
|
||||
}
|
||||
|
||||
// Qwen3Block represents a single Qwen3 transformer block
|
||||
type Qwen3Block struct {
|
||||
Attention *Qwen3Attention `weight:"self_attn"`
|
||||
MLP *Qwen3MLP `weight:"mlp"`
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the Qwen3 block
|
||||
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
h := qb.InputLayerNorm.Forward(x, eps)
|
||||
attnOut := qb.Attention.Forward(h)
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
h = qb.PostAttnLayerNorm.Forward(x, eps)
|
||||
mlpOut := qb.MLP.Forward(h)
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
|
||||
type Qwen3TextEncoder struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []*Qwen3Block `weight:"model.layers"`
|
||||
FinalNorm *nn.RMSNorm `weight:"model.norm"`
|
||||
*Qwen3Config
|
||||
}
|
||||
|
||||
// Load loads the Qwen3 text encoder from ollama blob storage.
|
||||
func (m *Qwen3TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading text encoder... ")
|
||||
|
||||
// Load config from blob
|
||||
var cfg Qwen3Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Qwen3Config = &cfg
|
||||
m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights from tensor blobs
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
return m.loadWeights(weights)
|
||||
}
|
||||
|
||||
// loadWeights loads weights from any WeightSource into the model
|
||||
func (m *Qwen3TextEncoder) loadWeights(weights safetensors.WeightSource) error {
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// initComputedFields initializes computed fields after loading weights
|
||||
func (m *Qwen3TextEncoder) initComputedFields() {
|
||||
cfg := m.Qwen3Config
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
for _, block := range m.Layers {
|
||||
// Attention
|
||||
block.Attention.NHeads = cfg.NumAttentionHeads
|
||||
block.Attention.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.Attention.HeadDim = cfg.HeadDim
|
||||
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.Attention.RopeTheta = cfg.RopeTheta
|
||||
block.Attention.QNorm.Eps = cfg.RMSNormEps
|
||||
block.Attention.KNorm.Eps = cfg.RMSNormEps
|
||||
// Block norms
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
h := te.EmbedTokens.Forward(tokens)
|
||||
eps := te.RMSNormEps
|
||||
|
||||
for _, layer := range te.Layers {
|
||||
h = layer.Forward(h, eps)
|
||||
}
|
||||
|
||||
// Apply final RMS norm
|
||||
h = te.FinalNorm.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// ApplyChatTemplate wraps prompt in Qwen3 chat format
|
||||
var ApplyChatTemplate = qwen3.ApplyChatTemplate
|
||||
func ApplyChatTemplate(prompt string) string {
|
||||
return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
|
||||
// EncodePrompt encodes a text prompt using the tokenizer and encoder
|
||||
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
|
||||
formattedPrompt := ApplyChatTemplate(prompt)
|
||||
|
||||
tokens := tok.Encode(formattedPrompt, false)
|
||||
|
||||
if len(tokens) > maxLen {
|
||||
tokens = tokens[:maxLen]
|
||||
}
|
||||
|
||||
maskData := make([]float32, maxLen)
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
maskData[i] = 1.0
|
||||
}
|
||||
|
||||
// Get PAD token (different from EOS for Qwen3)
|
||||
padToken := tok.PAD()
|
||||
if padToken < 0 {
|
||||
padToken = tok.EOS() // fallback
|
||||
}
|
||||
|
||||
paddedTokens := make([]int32, maxLen)
|
||||
copy(paddedTokens, tokens)
|
||||
for i := len(tokens); i < maxLen; i++ {
|
||||
paddedTokens[i] = padToken
|
||||
}
|
||||
|
||||
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
|
||||
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
|
||||
|
||||
embeddings := te.Forward(tokensArr)
|
||||
|
||||
return embeddings, maskArr
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
@@ -637,9 +636,6 @@ type VAEDecoder struct {
|
||||
UpBlocks []*UpDecoderBlock2D
|
||||
ConvNormOut *GroupNormLayer
|
||||
ConvOut *Conv2D
|
||||
|
||||
// Tiling configuration (nil = no tiling)
|
||||
Tiling *vae.TilingConfig
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from ollama blob storage.
|
||||
@@ -734,60 +730,45 @@ func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfi
|
||||
|
||||
// Decode decodes latents to images.
|
||||
// Input latents are in NCHW format, output is in NCHW format.
|
||||
// If Tiling is set, uses tiled decoding to reduce memory for large images.
|
||||
func (v *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Internally uses NHWC format (MLX native) for all operations.
|
||||
func (vae *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
// Scale latents
|
||||
z := mlx.DivScalar(latents, v.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, v.Config.ShiftFactor)
|
||||
z := mlx.DivScalar(latents, vae.Config.ScalingFactor)
|
||||
z = mlx.AddScalar(z, vae.Config.ShiftFactor)
|
||||
// Convert NCHW -> NHWC for internal processing
|
||||
z = mlx.Transpose(z, 0, 2, 3, 1)
|
||||
|
||||
// Use tiled decoding if enabled
|
||||
if v.Tiling != nil {
|
||||
mlx.Eval(z)
|
||||
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
|
||||
}
|
||||
|
||||
// Direct decode
|
||||
h := v.decodeTile(z)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// decodeTile decodes a single latent tile to pixels.
|
||||
// Input: [B, H, W, C] latent tile in NHWC format (already scaled)
|
||||
// Output: [B, H*8, W*8, 3] pixel tile in NHWC format
|
||||
func (v *VAEDecoder) decodeTile(z *mlx.Array) *mlx.Array {
|
||||
h := v.ConvIn.Forward(z)
|
||||
h := vae.ConvIn.Forward(z)
|
||||
mlx.Eval(h)
|
||||
|
||||
prev := h
|
||||
h = v.MidBlock.Forward(h)
|
||||
h = vae.MidBlock.Forward(h)
|
||||
prev.Free()
|
||||
|
||||
for _, upBlock := range v.UpBlocks {
|
||||
for _, upBlock := range vae.UpBlocks {
|
||||
prev = h
|
||||
h = upBlock.Forward(h)
|
||||
prev.Free()
|
||||
}
|
||||
|
||||
prev = h
|
||||
h = v.ConvNormOut.Forward(h)
|
||||
h = vae.ConvNormOut.Forward(h)
|
||||
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
|
||||
prev.Free()
|
||||
|
||||
prev = h
|
||||
h = mlx.SiLU(h)
|
||||
h = v.ConvOut.Forward(h)
|
||||
h = vae.ConvOut.Forward(h)
|
||||
mlx.Eval(h)
|
||||
prev.Free()
|
||||
|
||||
// VAE outputs [-1, 1], convert to [0, 1]
|
||||
h = mlx.MulScalar(h, 0.5)
|
||||
h = mlx.AddScalar(h, 0.5)
|
||||
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
|
||||
|
||||
// Convert NHWC -> NCHW for output
|
||||
h = mlx.Transpose(h, 0, 3, 1, 2)
|
||||
mlx.Eval(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -12,20 +12,19 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/imagegen/vae"
|
||||
)
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress func(step, totalSteps int) // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
NegativePrompt string // Empty = no CFG
|
||||
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
|
||||
Width int32 // Image width (default: 1024)
|
||||
Height int32 // Image height (default: 1024)
|
||||
Steps int // Denoising steps (default: 9 for turbo)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
CapturePath string // GPU capture path (debug)
|
||||
|
||||
// TeaCache options (timestep embedding aware caching)
|
||||
TeaCache bool // TeaCache is always enabled for faster inference
|
||||
@@ -35,6 +34,9 @@ type GenerateConfig struct {
|
||||
FusedQKV bool // Enable fused QKV projection (default: false)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with step progress.
|
||||
type ProgressFunc func(step, totalSteps int)
|
||||
|
||||
// Model represents a Z-Image diffusion model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
@@ -91,7 +93,7 @@ func (m *Model) Load(modelName string) error {
|
||||
|
||||
// Load text encoder
|
||||
m.TextEncoder = &Qwen3TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
@@ -137,7 +139,7 @@ func (m *Model) Generate(prompt string, width, height int32, steps int, seed int
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
@@ -149,7 +151,7 @@ func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps i
|
||||
}
|
||||
|
||||
// GenerateWithCFG creates an image with classifier-free guidance.
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
|
||||
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
NegativePrompt: negativePrompt,
|
||||
@@ -177,16 +179,9 @@ func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*m
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements runner.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal denoising pipeline.
|
||||
@@ -199,7 +194,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 9 // Z-Image turbo default
|
||||
cfg.Steps = 9 // Turbo default
|
||||
}
|
||||
if cfg.CFGScale <= 0 {
|
||||
cfg.CFGScale = 4.0
|
||||
@@ -227,9 +222,9 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
// Text encoding with padding to multiple of 32
|
||||
var posEmb, negEmb *mlx.Array
|
||||
{
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
|
||||
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512)
|
||||
if useCFG {
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
|
||||
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512)
|
||||
}
|
||||
|
||||
// Pad both to same length (multiple of 32)
|
||||
@@ -453,11 +448,7 @@ func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array,
|
||||
teaCache.Free()
|
||||
}
|
||||
|
||||
// VAE decode - enable tiling for larger images to reduce memory
|
||||
// VAE attention is O(n²) on latent pixels, tiling helps significantly
|
||||
if latentH > 64 || latentW > 64 {
|
||||
m.VAEDecoder.Tiling = vae.DefaultTilingConfig()
|
||||
}
|
||||
// VAE decode
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
latents.Free()
|
||||
|
||||
|
||||
@@ -3,34 +3,12 @@
|
||||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// TestMain initializes MLX before running tests.
|
||||
// If MLX libraries are not available, tests are skipped.
|
||||
func TestMain(m *testing.M) {
|
||||
// Change to repo root so ./build/lib/ollama/ path works
|
||||
_, thisFile, _, _ := runtime.Caller(0)
|
||||
repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..")
|
||||
if err := os.Chdir(repoRoot); err != nil {
|
||||
fmt.Printf("Failed to change to repo root: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
fmt.Printf("Skipping nn tests: %v\n", err)
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// TestLinearNoBias verifies Linear without bias computes x @ w.T correctly.
|
||||
func TestLinearNoBias(t *testing.T) {
|
||||
// Weight: [out=2, in=3] -> transposed at forward time
|
||||
|
||||
22
x/imagegen/quantize.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is true, returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||