Compare commits

5 Commits

Patrick Devine
3b95add4e3 x: make ollama create --experimental import from safetensors
This change allows importing safetensors models into the new experimental model format, and also
fixes the `ollama show` command so that it correctly displays the model information.
2026-01-15 15:38:44 -08:00
Jeffrey Morgan
4adb9cf4bb scripts: fix macOS auto-update signature verification failure (#13713)
Add --norsrc flag to ditto commands when creating Ollama-darwin.zip
to exclude AppleDouble resource fork files (._* files) from the archive.

The mlx.metallib file has extended attributes, which causes ditto to
include a ._mlx.metallib AppleDouble file in the zip. Since this file
is not part of the code signature seal, macOS rejects the bundle during
auto-update verification with:

  "a sealed resource is missing or invalid"
  "file added: .../._mlx.metallib"

The --norsrc flag prevents ditto from preserving resource forks and
extended attributes, ensuring only signed files are included in the
release archive.
2026-01-14 07:48:10 -08:00
Daniel Hiltgen
74f475e735 Revert "Documentation edits made through Mintlify web editor" (#13688)
This reverts commit c6d4c0c7f2.

Merge after 0.14.0 ships for the updated Linux documentation.
2026-01-14 07:42:34 -08:00
Maternion
875cecba74 docs: update default context window size to 4096 tokens (#13709) 2026-01-14 01:01:28 -08:00
Josh Daniel Bañares
7d411a4686 docs: update web search param in examples (#13711) 2026-01-14 00:38:39 -08:00
30 changed files with 2665 additions and 4852 deletions

@@ -48,7 +48,7 @@ ollama run gemma3
## Model library
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Here are some example models that can be downloaded:
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
./ollama run llama3.2
```
## Building with MLX (experimental)
First build the MLX libraries:
```shell
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
```
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
```shell
go build -tags mlx -o ollama-mlx .
```
Finally, start the server:
```
./ollama serve
```
### Building MLX with CUDA
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
```shell
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
```
## REST API
Ollama has a REST API for running and managing models.
@@ -453,7 +421,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
@@ -525,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Database
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
@@ -668,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
@@ -677,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

@@ -46,8 +46,9 @@ import (
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
"github.com/ollama/ollama/x/create"
xcreateclient "github.com/ollama/ollama/x/create/client"
"github.com/ollama/ollama/x/imagegen"
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@@ -93,15 +94,82 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
// Check for --experimental flag for safetensors model creation
experimental, _ := cmd.Flags().GetBool("experimental")
if experimental {
modelName := args[0]
// Get Modelfile content - either from -f flag or default to "FROM ."
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) || filename == "" {
// No Modelfile specified or found - use default
reader = strings.NewReader("FROM .\n")
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
reader = f
}
// Parse the Modelfile
modelfile, err := parser.ParseFile(reader)
if err != nil {
return fmt.Errorf("failed to parse Modelfile: %w", err)
}
// Extract FROM path and configuration
var modelDir string
mfConfig := &xcreateclient.ModelfileConfig{}
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model":
modelDir = cmd.Args
case "template":
mfConfig.Template = cmd.Args
case "system":
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
}
}
if modelDir == "" {
modelDir = "."
}
// Resolve relative paths based on Modelfile location
if !filepath.IsAbs(modelDir) && filename != "" {
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
}
quantize, _ := cmd.Flags().GetString("quantize")
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: modelName,
ModelDir: modelDir,
Quantize: quantize,
Modelfile: mfConfig,
}, p)
}
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
// No Modelfile found - check if current directory is an image gen model
if imagegen.IsTensorModelDir(".") {
if create.IsTensorModelDir(".") {
quantize, _ := cmd.Flags().GetString("quantize")
return imagegenclient.CreateModel(args[0], ".", quantize, p)
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
ModelName: args[0],
ModelDir: ".",
Quantize: quantize,
}, p)
}
reader = strings.NewReader("FROM .\n")
} else {
@@ -1742,15 +1810,22 @@ func NewCLI() *cobra.Command {
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: CreateHandler,
Use: "create MODEL",
Short: "Create a model",
Args: cobra.ExactArgs(1),
PreRunE: func(cmd *cobra.Command, args []string) error {
// Skip server check for experimental mode (writes directly to disk)
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
return nil
}
return checkServerHeartbeat(cmd, args)
},
RunE: CreateHandler,
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
showCmd := &cobra.Command{
Use: "show MODEL",

@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
import { Ollama } from "ollama";
const client = new Ollama();
const results = await client.webSearch({ query: "what is ollama?" });
const results = await client.webSearch("what is ollama?");
console.log(JSON.stringify(results, null, 2));
```
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
import { Ollama } from "ollama";
const client = new Ollama();
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
const fetchResult = await client.webFetch("https://ollama.com");
console.log(JSON.stringify(fetchResult, null, 2));
```

@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
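(The hunk cuts off before the command itself. As an aside, the limit can also be raised per request rather than globally; a minimal sketch with the Go API client, assuming its current `GenerateRequest`/`Options` shape, sets `num_ctx`:)

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.GenerateRequest{
		Model:  "llama3.2",
		Prompt: "hello",
		// Raise the context window for this request only.
		Options: map[string]any{"num_ctx": 8192},
	}
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```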

@@ -1,5 +1,5 @@
---
title: "Linux"
title: Linux
---
## Install
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
<Note>
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
If you are upgrading from a prior version, you should remove the old libraries
with `sudo rm -rf /usr/lib/ollama` first.
</Note>
Download and extract the package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
Start Ollama:
@@ -40,8 +41,8 @@ ollama -v
If you have an AMD GPU, also download and extract the additional ROCm package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
| sudo tar x -C /usr
```
### ARM64 install
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
Download and extract the ARM64-specific package:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
| sudo tar x -C /usr
```
### Adding Ollama as a startup service (recommended)
@@ -112,7 +113,11 @@ sudo systemctl status ollama
```
<Note>
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
While AMD has contributed the `amdgpu` driver upstream to the official linux
kernel source, the version is older and may not support all ROCm features. We
recommend you install the latest driver from
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
GPU.
</Note>
## Customizing
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
Or by re-downloading Ollama:
```shell
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
| sudo tar zx -C /usr
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
| sudo tar x -C /usr
```
## Installing specific versions
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
sudo userdel ollama
sudo groupdel ollama
sudo rm -r /usr/share/ollama
```

@@ -179,7 +179,7 @@ _build_macapp() {
fi
rm -f dist/Ollama-darwin.zip
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
# Notarize and Staple
@@ -187,7 +187,7 @@ _build_macapp() {
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
rm -f dist/Ollama-darwin.zip
$(xcrun -f stapler) staple dist/Ollama.app
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
rm -f dist/Ollama.dmg
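
As a sanity check that `--norsrc` really kept AppleDouble files out of the release archive (the failure mode described in the commit message above), a small hypothetical helper — not part of the build scripts — could scan the zip for `._*` entries:

```go
package main

import (
	"archive/zip"
	"fmt"
	"log"
	"os"
	"path"
	"strings"
)

func main() {
	// Usage: go run check.go dist/Ollama-darwin.zip
	r, err := zip.OpenReader(os.Args[1])
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()

	clean := true
	for _, f := range r.File {
		// AppleDouble entries ("._<name>") carry resource forks and extended
		// attributes; any such file would break the code-signature seal.
		if strings.HasPrefix(path.Base(f.Name), "._") {
			fmt.Println("AppleDouble entry:", f.Name)
			clean = false
		}
	}
	if clean {
		fmt.Println("no AppleDouble entries found")
	}
}
```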

@@ -52,6 +52,7 @@ import (
"github.com/ollama/ollama/version"
"github.com/ollama/ollama/x/imagegen"
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
@@ -1133,6 +1134,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
if req.System != "" {
m.System = req.System
}
@@ -1219,6 +1236,20 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
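
From a client's perspective, the effect of these server changes is that `ollama show` (and the `Show` API) now reports family, parameter count, and dtype for experimental safetensors models. A minimal sketch with the Go API client, assuming its current signatures, where "my-model" stands in for a model imported with `ollama create --experimental`:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "my-model"})
	if err != nil {
		log.Fatal(err)
	}
	// For safetensors LLMs these now come from config.json; QuantizationLevel
	// mirrors torch_dtype (per the server changes above).
	fmt.Println(resp.Details.Family, resp.Details.ParameterSize, resp.Details.QuantizationLevel)
}
```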

x/README.md (new file, 50 lines)

@@ -0,0 +1,50 @@
# Experimental Features
## MLX Backend
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx).
Support is currently limited to macOS and to Linux with CUDA GPUs. We're looking to add support for Windows with CUDA soon, as well as other GPU vendors.
### Building ollama-mlx
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
#### macOS (Apple Silicon and Intel)
```bash
# Build MLX backend libraries
cmake --preset MLX
cmake --build --preset MLX --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
go build -tags mlx -o ollama-mlx .
```
#### Linux (CUDA)
On Linux, use the "MLX CUDA 13" or "MLX CUDA 12" preset to enable CUDA with the default Ollama NVIDIA GPU architectures:
```bash
# Build MLX backend libraries with CUDA support
cmake --preset 'MLX CUDA 13'
cmake --build --preset 'MLX CUDA 13' --parallel
cmake --install build --component MLX
# Build ollama-mlx binary
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
go build -tags mlx -o ollama-mlx .
```
#### Using build scripts
The build scripts automatically create the `ollama-mlx` binary:
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
## Image Generation
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.

x/create/client/create.go (new file, 282 lines)

@@ -0,0 +1,282 @@
// Package client provides client-side model creation for safetensors-based models.
//
// This package is in x/ because the safetensors model storage format is under development.
// It also exists to break an import cycle: server imports x/create, so x/create
// cannot import server. This sub-package can import server because server doesn't
// import it.
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
Template string
System string
License string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
Quantize string // "fp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile
}
// CreateModel imports a model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Detect model type
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
isImageGen := create.IsTensorModelDir(opts.ModelDir)
if !isSafetensors && !isImageGen {
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
}
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
capabilities = []string{"image"}
}
// Set up progress spinner
statusMsg := "importing " + modelType
spinner := progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
progressFn := func(msg string) {
spinner.Stop()
statusMsg = msg
spinner = progress.NewSpinner(statusMsg)
p.Add(spinnerKey, spinner)
}
// Create the model using shared callbacks
var err error
if isSafetensors {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
newManifestWriter(opts, capabilities),
progressFn,
)
}
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
return nil
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
return create.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
}
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
// When doQuantize is true, returns multiple layers (weight + scales + optional qbias).
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
return func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]create.LayerInfo, error) {
if doQuantize {
return createQuantizedLayers(r, name, dtype, shape)
}
return createUnquantizedLayer(r, name)
}
}
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []create.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name,
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale",
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, create.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias",
})
}
return layers, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []create.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: capabilities,
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
// Add Modelfile layers if present
if opts.Modelfile != nil {
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
if err != nil {
return err
}
serverLayers = append(serverLayers, modelfileLayers...)
}
return server.WriteManifest(name, configLayer, serverLayers)
}
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
var layers []server.Layer
if mf.Template != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
layers = append(layers, layer)
}
if mf.System != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
layers = append(layers, layer)
}
if mf.License != "" {
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
layers = append(layers, layer)
}
return layers, nil
}
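
For orientation, here is a minimal sketch of driving this package directly, mirroring what `ollama create --experimental` does in cmd.go above (the model name and directory are placeholders):

```go
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/progress"
	xcreateclient "github.com/ollama/ollama/x/create/client"
)

func main() {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	// ModelDir must contain config.json plus *.safetensors (LLM) or
	// model_index.json (image generation); type detection is automatic.
	err := xcreateclient.CreateModel(xcreateclient.CreateOptions{
		ModelName: "my-model",
		ModelDir:  "./path/to/model",
		Quantize:  "", // or "fp8" in MLX-enabled builds
	}, p)
	if err != nil {
		log.Fatal(err)
	}
}
```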

@@ -0,0 +1,146 @@
package client
import (
"testing"
)
func TestModelfileConfig(t *testing.T) {
// Test that ModelfileConfig struct works as expected
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
}
if config.Template != "{{ .Prompt }}" {
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
}
if config.System != "You are a helpful assistant." {
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
}
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
}
func TestModelfileConfig_Empty(t *testing.T) {
config := &ModelfileConfig{}
if config.Template != "" {
t.Errorf("Template should be empty, got %q", config.Template)
}
if config.System != "" {
t.Errorf("System should be empty, got %q", config.System)
}
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
}
func TestModelfileConfig_PartialFields(t *testing.T) {
// Test config with only some fields set
config := &ModelfileConfig{
Template: "{{ .Prompt }}",
// System and License intentionally empty
}
if config.Template == "" {
t.Error("Template should not be empty")
}
if config.System != "" {
t.Error("System should be empty")
}
if config.License != "" {
t.Error("License should be empty")
}
}
func TestMinOllamaVersion(t *testing.T) {
// Verify the minimum version constant is set
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
}
}
func TestCreateModel_InvalidDir(t *testing.T) {
// Test that CreateModel returns error for invalid directory
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: "/nonexistent/path",
}, nil)
if err == nil {
t.Error("expected error for nonexistent directory, got nil")
}
}
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
// Test that CreateModel returns error for directory without safetensors
dir := t.TempDir()
err := CreateModel(CreateOptions{
ModelName: "test-model",
ModelDir: dir,
}, nil)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateOptions(t *testing.T) {
opts := CreateOptions{
ModelName: "my-model",
ModelDir: "/path/to/model",
Quantize: "fp8",
Modelfile: &ModelfileConfig{
Template: "test",
System: "system",
License: "MIT",
},
}
if opts.ModelName != "my-model" {
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
}
if opts.ModelDir != "/path/to/model" {
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
}
if opts.Quantize != "fp8" {
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
}
if opts.Modelfile == nil {
t.Error("Modelfile should not be nil")
}
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
}
func TestCreateOptions_Defaults(t *testing.T) {
opts := CreateOptions{
ModelName: "test",
ModelDir: "/tmp",
}
// Quantize should default to empty
if opts.Quantize != "" {
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
}
// Modelfile should default to nil
if opts.Modelfile != nil {
t.Error("Modelfile should be nil by default")
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
supported := QuantizeSupported()
// In non-mlx builds, this should be false
// We can't easily test both cases, so just verify it returns something
_ = supported
}

x/create/create.go (new file, 391 lines)

@@ -0,0 +1,391 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// ModelConfig represents the config blob stored with a model.
type ModelConfig struct {
ModelFormat string `json:"model_format"`
Capabilities []string `json:"capabilities"`
}
// Manifest represents the manifest JSON structure.
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config ManifestLayer `json:"config"`
Layers []ManifestLayer `json:"layers"`
}
// ManifestLayer represents a layer in the manifest.
type ManifestLayer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
Name string `json:"name,omitempty"`
}
// defaultManifestDir returns the manifest storage directory.
func defaultManifestDir() string {
return filepath.Join(envconfig.Models(), "manifests")
}
// defaultBlobDir returns the blob storage directory.
func defaultBlobDir() string {
return filepath.Join(envconfig.Models(), "blobs")
}
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
}
// loadManifest loads a manifest for the given model name.
func loadManifest(modelName string) (*Manifest, error) {
manifestPath := resolveManifestPath(modelName)
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, err
}
var manifest Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
return nil, err
}
return &manifest, nil
}
// loadModelConfig loads the config blob for a model.
func loadModelConfig(modelName string) (*ModelConfig, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return nil, err
}
// Read the config blob
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return nil, err
}
var config ModelConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil, err
}
return &config, nil
}
// IsSafetensorsModel checks if a model was created with the experimental
// safetensors builder by checking the model format in the config.
func IsSafetensorsModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors"
}
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
// (has completion capability, not image generation).
func IsSafetensorsLLMModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
}
// IsImageGenModel checks if a model is an image generation model
// (has image capability).
func IsImageGenModel(modelName string) bool {
config, err := loadModelConfig(modelName)
if err != nil {
return false
}
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
}
// GetModelArchitecture returns the architecture from the model's config.json layer.
func GetModelArchitecture(modelName string) (string, error) {
manifest, err := loadManifest(modelName)
if err != nil {
return "", err
}
// Find the config.json layer
for _, layer := range manifest.Layers {
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
blobName := strings.Replace(layer.Digest, ":", "-", 1)
blobPath := filepath.Join(defaultBlobDir(), blobName)
data, err := os.ReadFile(blobPath)
if err != nil {
return "", err
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", err
}
// Prefer model_type, fall back to first architecture
if cfg.ModelType != "" {
return cfg.ModelType, nil
}
if len(cfg.Architectures) > 0 {
return cfg.Architectures[0], nil
}
}
}
return "", fmt.Errorf("architecture not found in model config")
}
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
// by looking for config.json and at least one .safetensors file.
func IsSafetensorsModelDir(dir string) bool {
// Must have config.json
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
return false
}
// Must have at least one .safetensors file
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if strings.HasSuffix(entry.Name(), ".safetensors") {
return true
}
}
return false
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
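// Taken together, an import runs roughly as follows (see CreateSafetensorsModel
// below): every tensor goes through the QuantizingTensorLayerCreator, every
// JSON file except model.safetensors.index.json goes through the LayerCreator,
// and the accumulated LayerInfo slice is handed to the ManifestWriter along
// with the config.json layer.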
// ShouldQuantize returns true if a tensor should be quantized.
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
func ShouldQuantize(name, component string) bool {
// Image gen specific: skip VAE entirely
if component == "vae" {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
}
// Skip layer norms and RMS norms
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
return false
}
// Skip biases
if strings.HasSuffix(name, ".bias") {
return false
}
// Only quantize weights
return strings.HasSuffix(name, ".weight")
}
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
// This is a more detailed check that also considers tensor dimensions.
func ShouldQuantizeTensor(name string, shape []int32) bool {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
return false
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return false
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return false
}
return true
}
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
}
// Process all safetensors files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
continue
}
stPath := filepath.Join(modelDir, entry.Name())
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
tensorNames := extractor.ListTensors()
quantizeMsg := ""
if quantize != "" {
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine if this tensor should be quantized
doQuantize := quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape)
// Store as minimal safetensors format (88 bytes header overhead)
// This enables native mmap loading via mlx_load_safetensors
// createTensorLayer returns multiple layers if quantizing (weight + scales)
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, doQuantize)
if err != nil {
extractor.Close()
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
}
layers = append(layers, newLayers...)
}
extractor.Close()
}
// Process all JSON config files
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
// Skip the index file as we don't need it after extraction
if entry.Name() == "model.safetensors.index.json" {
continue
}
cfgPath := entry.Name()
fullPath := filepath.Join(modelDir, cfgPath)
fn(fmt.Sprintf("importing config %s", cfgPath))
f, err := os.Open(fullPath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
}
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
f.Close()
if err != nil {
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
}
// Use config.json as the config layer
if cfgPath == "config.json" {
configLayer = layer
}
layers = append(layers, layer)
}
if configLayer.Digest == "" {
return fmt.Errorf("config.json not found in %s", modelDir)
}
fn(fmt.Sprintf("writing manifest for %s", modelName))
if err := writeManifest(modelName, configLayer, layers); err != nil {
return fmt.Errorf("failed to write manifest: %w", err)
}
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
return nil
}

x/create/create_test.go (new file, 752 lines)

@@ -0,0 +1,752 @@
package create
import (
"bytes"
"encoding/binary"
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestIsTensorModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid diffusers model with model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0644)
},
expected: true,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "directory with other files but no model_index.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsTensorModelDir(dir)
if got != tt.expected {
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir(t *testing.T) {
tests := []struct {
name string
setup func(dir string) error
expected bool
}{
{
name: "valid safetensors model with config.json and .safetensors file",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
},
expected: true,
},
{
name: "config.json only, no safetensors files",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
},
expected: false,
},
{
name: "safetensors file only, no config.json",
setup: func(dir string) error {
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
},
expected: false,
},
{
name: "empty directory",
setup: func(dir string) error {
return nil
},
expected: false,
},
{
name: "multiple safetensors files with config.json",
setup: func(dir string) error {
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0644); err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0644)
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := tt.setup(dir); err != nil {
t.Fatalf("setup failed: %v", err)
}
got := IsSafetensorsModelDir(dir)
if got != tt.expected {
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
if got != false {
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
}
}
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
func createMinimalSafetensors(t *testing.T, path string) {
t.Helper()
// Create a minimal safetensors file with a single float32 tensor
header := map[string]interface{}{
"test_tensor": map[string]interface{}{
"dtype": "F32",
"shape": []int{2, 2},
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
},
}
headerJSON, err := json.Marshal(header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
// Write file
f, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create file: %v", err)
}
defer f.Close()
// Write header size (8 bytes, little endian)
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
// Write header
if _, err := f.Write(headerJSON); err != nil {
t.Fatalf("failed to write header: %v", err)
}
// Write tensor data (16 bytes of zeros for 4 float32 values)
if _, err := f.Write(make([]byte, 16)); err != nil {
t.Fatalf("failed to write tensor data: %v", err)
}
}
func TestCreateSafetensorsModel(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Track what was created
var createdLayers []LayerInfo
var manifestWritten bool
var manifestModelName string
var manifestConfigLayer LayerInfo
var manifestLayers []LayerInfo
var statusMessages []string
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return LayerInfo{}, err
}
layer := LayerInfo{
Digest: "sha256:test",
Size: int64(len(data)),
MediaType: mediaType,
Name: name,
}
createdLayers = append(createdLayers, layer)
return layer, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
layer := LayerInfo{
Digest: "sha256:tensor_" + name,
Size: int64(len(data)),
MediaType: "application/vnd.ollama.image.tensor",
Name: name,
}
createdLayers = append(createdLayers, layer)
return []LayerInfo{layer}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
manifestConfigLayer = config
manifestLayers = layers
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
// Run CreateSafetensorsModel
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify manifest was written
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-model" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
}
// Verify config layer was set
if manifestConfigLayer.Name != "config.json" {
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
}
// Verify we have at least one tensor and one config layer
hasTensor := false
hasConfig := false
for _, layer := range manifestLayers {
if layer.Name == "test_tensor" {
hasTensor = true
}
if layer.Name == "config.json" {
hasConfig = true
}
}
if !hasTensor {
t.Error("no tensor layer found in manifest")
}
if !hasConfig {
t.Error("no config layer found in manifest")
}
// Verify status messages were sent
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
dir := t.TempDir()
// Create only a safetensors file, no config.json
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
// Mock callbacks (minimal)
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing config.json, got nil")
}
}
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
dir := t.TempDir()
// Mock callbacks
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
return LayerInfo{}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
return []LayerInfo{{}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for empty directory, got nil")
}
}
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
dir := t.TempDir()
// Create config.json
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create model.safetensors.index.json (should be skipped)
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0644); err != nil {
t.Fatalf("failed to write index.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var configNames []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
configNames = append(configNames, name)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify model.safetensors.index.json was not included
for _, name := range configNames {
if name == "model.safetensors.index.json" {
t.Error("model.safetensors.index.json should have been skipped")
}
}
}
func TestResolveManifestPath(t *testing.T) {
tests := []struct {
name string
modelName string
wantParts []string // Parts that should appear in the path
}{
{
name: "simple model name",
modelName: "llama2",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
},
{
name: "model name with tag",
modelName: "llama2:7b",
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
},
{
name: "model name with namespace",
modelName: "myuser/mymodel",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
},
{
name: "model name with namespace and tag",
modelName: "myuser/mymodel:v1",
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
},
{
name: "fully qualified model name",
modelName: "registry.example.com/namespace/model:tag",
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolveManifestPath(tt.modelName)
for _, part := range tt.wantParts {
if !strings.Contains(got, part) {
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
}
}
})
}
}
func TestLayerInfo(t *testing.T) {
layer := LayerInfo{
Digest: "sha256:abc123",
Size: 1024,
MediaType: "application/vnd.ollama.image.tensor",
Name: "model.weight",
}
if layer.Digest != "sha256:abc123" {
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
}
if layer.Size != 1024 {
t.Errorf("Size = %d, want %d", layer.Size, 1024)
}
if layer.MediaType != "application/vnd.ollama.image.tensor" {
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
}
if layer.Name != "model.weight" {
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
}
}
func TestModelConfig(t *testing.T) {
config := ModelConfig{
ModelFormat: "safetensors",
Capabilities: []string{"completion", "chat"},
}
if config.ModelFormat != "safetensors" {
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
}
if len(config.Capabilities) != 2 {
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
}
}
func TestManifest(t *testing.T) {
manifest := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.oci.image.manifest.v1+json",
Config: ManifestLayer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:config",
Size: 100,
},
Layers: []ManifestLayer{
{
MediaType: "application/vnd.ollama.image.tensor",
Digest: "sha256:layer1",
Size: 1000,
Name: "weight.bin",
},
},
}
if manifest.SchemaVersion != 2 {
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
}
if manifest.Config.Digest != "sha256:config" {
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
}
if len(manifest.Layers) != 1 {
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
}
if manifest.Layers[0].Name != "weight.bin" {
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
}
}
func TestShouldQuantize(t *testing.T) {
tests := []struct {
name string
tensor string
component string
want bool
}{
// VAE component should never be quantized
{"vae weight", "decoder.weight", "vae", false},
{"vae bias", "decoder.bias", "vae", false},
// Embeddings should not be quantized
{"embedding weight", "embed_tokens.weight", "", false},
{"embedding in name", "token_embedding.weight", "", false},
// Norms should not be quantized
{"layer norm", "layer_norm.weight", "", false},
{"rms norm", "rms_norm.weight", "", false},
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
// Linear weights should be quantized
{"linear weight", "q_proj.weight", "", true},
{"attention weight", "self_attn.weight", "", true},
{"mlp weight", "mlp.gate_proj.weight", "", true},
// Transformer component weights should be quantized
{"transformer weight", "layers.0.weight", "transformer", true},
{"text_encoder weight", "encoder.weight", "text_encoder", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantize(tt.tensor, tt.component)
if got != tt.want {
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
}
})
}
}
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
name string
tensor string
shape []int32
want bool
}{
// 2D tensors with sufficient size should be quantized
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
// Small tensors should not be quantized (< 1024 elements)
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
{"small 2D weight", "small.weight", []int32{31, 31}, false},
// 1D tensors should not be quantized
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
// Norms should not be quantized regardless of shape
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
// Biases should not be quantized
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
if got != tt.want {
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
}
})
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
// Create config.json
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
t.Fatalf("failed to write config.json: %v", err)
}
// Create a minimal safetensors file
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
var quantizeRequested []bool
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
// Run with quantize enabled
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateSafetensorsModel failed: %v", err)
}
// Verify the quantize flag was plumbed through to the callback; the minimal test tensor is below the size threshold, so the flag will be false
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}
// createMinimalImageGenModel creates a minimal diffusers-style model directory
func createMinimalImageGenModel(t *testing.T, dir string) {
t.Helper()
// Create model_index.json
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0644); err != nil {
t.Fatalf("failed to write model_index.json: %v", err)
}
// Create transformer directory with a safetensors file
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
// Create transformer config
transformerConfig := `{"hidden_size": 3072}`
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0644); err != nil {
t.Fatalf("failed to write transformer config: %v", err)
}
}
func TestCreateImageGenModel(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var manifestWritten bool
var manifestModelName string
var statusMessages []string
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
manifestWritten = true
manifestModelName = modelName
return nil
}
progressFn := func(status string) {
statusMessages = append(statusMessages, status)
}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if !manifestWritten {
t.Error("manifest was not written")
}
if manifestModelName != "test-imagegen" {
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
}
if len(statusMessages) == 0 {
t.Error("no status messages received")
}
}
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
dir := t.TempDir()
// Create only transformer without model_index.json
transformerDir := filepath.Join(dir, "transformer")
if err := os.MkdirAll(transformerDir, 0755); err != nil {
t.Fatalf("failed to create transformer dir: %v", err)
}
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
if err == nil {
t.Error("expected error for missing model_index.json, got nil")
}
}
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
dir := t.TempDir()
createMinimalImageGenModel(t, dir)
var quantizeRequested []bool
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
io.ReadAll(r)
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
}
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
io.ReadAll(r)
quantizeRequested = append(quantizeRequested, quantize)
return []LayerInfo{{Name: name}}, nil
}
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
return nil
}
progressFn := func(status string) {}
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
if len(quantizeRequested) == 0 {
t.Error("no tensors processed")
}
}

View File

@@ -1,4 +1,4 @@
package imagegen
package create
import (
"bytes"
@@ -12,43 +12,17 @@ import (
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// IsTensorModelDir checks if the directory contains a tensor model
// by looking for model_index.json, which is the standard diffusers pipeline config.
func IsTensorModelDir(dir string) bool {
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
return err == nil
}
// LayerInfo holds metadata for a created layer.
type LayerInfo struct {
Digest string
Size int64
MediaType string
Name string // Path-style name: "component/tensor" or "path/to/config.json"
}
// LayerCreator is called to create a blob layer.
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
// TensorLayerCreator creates a tensor blob layer with metadata.
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// ManifestWriter writes the manifest file.
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
// CreateModel imports an image generation model from a directory.
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
var layers []LayerInfo
var configLayer LayerInfo
var totalParams int64 // Count parameters from original tensor shapes
// Components to process - extract individual tensors from each
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
components := []string{"text_encoder", "transformer", "vae"}
for _, component := range components {
componentDir := filepath.Join(modelDir, component)
@@ -126,13 +100,10 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
"text_encoder/generation_config.json",
"transformer/config.json",
"vae/config.json",
"vision_language_encoder/config.json",
"scheduler/scheduler_config.json",
"tokenizer/tokenizer.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"processor/tokenizer.json", // GLM-Image main tokenizer
"processor/tokenizer_config.json", // GLM-Image tokenizer config
}
for _, cfgPath := range configFiles {

View File

@@ -1,190 +0,0 @@
// Package client provides client-side model creation for tensor-based models.
//
// This package is in x/ because the tensor model storage format is under development.
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
// cannot import server. This sub-package can import server because server doesn't
// import it.
//
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
// 1. Add proper API endpoints for tensor model creation
// 2. Move tensor extraction to server-side
// 3. Remove this package
// 4. Follow the same client→server pattern as regular model creation
package client
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
)
// MinOllamaVersion is the minimum Ollama version required for image generation models.
const MinOllamaVersion = "0.14.0"
// CreateModel imports a tensor-based model from a local directory.
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
//
// TODO (jmorganca): Replace with API-based creation when promoted to production.
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
if !imagegen.IsTensorModelDir(modelDir) {
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
}
status := "importing image generation model"
spinner := progress.NewSpinner(status)
p.Add("imagegen", spinner)
// Create layer callback for config files
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
layer, err := server.NewLayer(r, mediaType)
if err != nil {
return imagegen.LayerInfo{}, err
}
layer.Name = name
return imagegen.LayerInfo{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
}, nil
}
// Create tensor layer callback for individual tensors
// name is path-style: "component/tensor_name"
// When quantize is true, returns multiple layers (weight + scales)
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
if doQuantize {
// Check if quantization is supported
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
// Quantize the tensor (affine mode returns weight, scales, qbiases)
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
// Create layer for quantized weight
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
// Create layer for scales (use _scale suffix convention)
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers := []imagegen.LayerInfo{
{
Digest: weightLayer.Digest,
Size: weightLayer.Size,
MediaType: weightLayer.MediaType,
Name: name, // Keep original name for weight
},
{
Digest: scalesLayer.Digest,
Size: scalesLayer.Size,
MediaType: scalesLayer.MediaType,
Name: name + "_scale", // Add _scale suffix
},
}
// Add qbiases layer if present (affine mode)
if qbiasData != nil {
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
layers = append(layers, imagegen.LayerInfo{
Digest: qbiasLayer.Digest,
Size: qbiasLayer.Size,
MediaType: qbiasLayer.MediaType,
Name: name + "_qbias", // Add _qbias suffix
})
}
return layers, nil
}
// Non-quantized path: just create a single layer
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
if err != nil {
return nil, err
}
return []imagegen.LayerInfo{
{
Digest: layer.Digest,
Size: layer.Size,
MediaType: layer.MediaType,
Name: name,
},
}, nil
}
// Create manifest writer callback
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
// Create a proper config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
Capabilities: []string{"image"},
Requires: MinOllamaVersion,
}
configJSON, err := json.Marshal(configData)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create config layer blob
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
serverLayers := make([]server.Layer, len(layers))
for i, l := range layers {
serverLayers[i] = server.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
Name: l.Name,
}
}
return server.WriteManifest(name, configLayer, serverLayers)
}
// Progress callback
progressFn := func(msg string) {
spinner.Stop()
status = msg
spinner = progress.NewSpinner(status)
p.Add("imagegen", spinner)
}
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
spinner.Stop()
if err != nil {
return err
}
fmt.Printf("Created image generation model '%s'\n", modelName)
return nil
}
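// Typical wiring for this helper (illustrative; the exact call site may
// differ):
//
//	p := progress.NewProgress(os.Stderr)
//	defer p.Stop()
//	if err := client.CreateModel("my-model", "/path/to/diffusers/dir", "fp8", p); err != nil {
//		log.Fatal(err)
//	}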

View File

@@ -11,11 +11,9 @@ import (
"os"
"path/filepath"
"runtime/pprof"
"strings"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -63,7 +61,6 @@ func main() {
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
@@ -120,33 +117,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
case *glmImageFlag:
m := &glm_image.Model{}
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
var loadErr error
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
loadErr = m.LoadFromPath(*modelPath)
} else {
loadErr = m.Load(*modelPath)
}
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
Temperature: float32(*temperature),
TopP: float32(*topP),
GuidanceScale: float32(*cfgScale),
MaxVisualTokens: int32(*maxTokens),
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {

View File

@@ -1,19 +0,0 @@
# Image generation models (experimental)
Experimental image generation models are available for **macOS** in Ollama:
## Available models
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
```
ollama run x/z-image-turbo
```
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models.
More models coming soon:
1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image

View File

@@ -27,7 +27,6 @@ var modelVRAMEstimates = map[string]uint64{
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
}
// CheckPlatformSupport validates that image generation is supported on the current platform.

View File

@@ -1,693 +0,0 @@
//go:build mlx
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image
import (
"context"
"fmt"
"math"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
PadTokenID int32
EOSTokenID int32
UNKTokenID int32
}
// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
return &ByT5Tokenizer{
PadTokenID: 0,
EOSTokenID: 1,
UNKTokenID: 2,
}
}
// Encode converts a string to token IDs
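// For example, "hi" encodes to [107, 108]: bytes 104 and 105 shifted up by 3
// to make room for the PAD/EOS/UNK special tokens.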
func (t *ByT5Tokenizer) Encode(text string) []int32 {
bytes := []byte(text)
tokens := make([]int32, len(bytes))
for i, b := range bytes {
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
// (tokens 0, 1, 2 are PAD, EOS, UNK)
tokens[i] = int32(b) + 3
}
return tokens
}
// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
bytes := make([]byte, 0, len(tokens))
for _, tok := range tokens {
if tok >= 3 && tok < 259 {
bytes = append(bytes, byte(tok-3))
}
}
return string(bytes)
}
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
GuidanceScale float32 // Guidance scale (default: 1.5)
Width int32 // Image width (default: 1024, must be divisible by 32)
Height int32 // Image height (default: 1024, must be divisible by 32)
Steps int // Diffusion denoising steps (default: 50)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// AR generation options
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
Temperature float32 // AR sampling temperature (default: 0.9)
TopP float32 // Nucleus sampling (default: 0.75)
}
// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)
// Model represents a GLM-Image hybrid model.
type Model struct {
ModelName string
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
TextEncoder *T5TextEncoder
VisionLanguageEncoder *VisionLanguageEncoder
Transformer *DiffusionTransformer
VAEDecoder *VAEDecoder
}
// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
// Used for T5 text encoder (glyph embeddings)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizer(manifest)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder (~830MB)
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.Load(manifest); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder (~19GB, 9B params)
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer (~13GB, 7B params)
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder (~775MB)
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelPath
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
fmt.Print(" Creating ByT5 tokenizer... ")
m.Tokenizer = NewByT5Tokenizer()
fmt.Println("✓")
// Load GLM tokenizer for AR model (visual token generation)
fmt.Print(" Loading GLM tokenizer... ")
glmTok, err := NewGLMTokenizerFromPath(modelPath)
if err != nil {
return fmt.Errorf("glm tokenizer: %w", err)
}
m.GLMTokenizer = glmTok
fmt.Println("✓")
// Load T5 text encoder
m.TextEncoder = &T5TextEncoder{}
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load vision-language encoder
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
return fmt.Errorf("vision language encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load diffusion transformer
m.Transformer = &DiffusionTransformer{}
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 50
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.5
}
// Calculate MaxVisualTokens based on image dimensions
// GLM-Image generates TWO grids of visual tokens:
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
// 2. Then: target (large) grid - tokenH × tokenW tokens
// After generation, we extract only the TARGET grid tokens for diffusion.
factor := int32(32)
tokenH := cfg.Height / factor
tokenW := cfg.Width / factor
targetGridTokens := tokenH * tokenW
// Compute prev grid dimensions using diffusers formula:
// ratio = token_h / token_w
// prev_token_h = int(sqrt(ratio) * 16)
// prev_token_w = int(sqrt(1/ratio) * 16)
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
prevGridTokens := prevTokenH * prevTokenW
// Total tokens to generate = prev grid + target grid
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
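// Worked example: for a 1024x1024 image, tokenH = tokenW = 1024/32 = 32, so
// targetGridTokens = 1024; ratio = 1, so prevTokenH = prevTokenW = 16 and
// prevGridTokens = 256, giving MaxVisualTokens = 1280.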
if cfg.Temperature <= 0 {
cfg.Temperature = 0.9
}
if cfg.TopP <= 0 {
cfg.TopP = 0.75
}
// Ensure dimensions are divisible by 32
cfg.Width = (cfg.Width / 32) * 32
cfg.Height = (cfg.Height / 32) * 32
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
// Progress callback helper
progress := func(stage string, step, total int) {
if cfg.Progress != nil {
cfg.Progress(stage, step, total)
}
}
// === PHASE 1: T5 Text Encoding ===
fmt.Println("[T5] Encoding glyph text...")
progress("text_encoding", 0, 1)
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
mlx.Keep(textEmbed)
mlx.Eval(textEmbed)
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
progress("text_encoding", 1, 1)
// === PHASE 2: AR Visual Token Generation ===
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
visualTokens := m.VisionLanguageEncoder.Generate(
cfg.Prompt,
m.GLMTokenizer,
cfg.MaxVisualTokens,
cfg.Temperature,
cfg.TopP,
cfg.Seed,
cfg.Height,
cfg.Width,
func(step int) {
if step%100 == 0 || step < 10 {
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
}
progress("ar_generation", step, int(cfg.MaxVisualTokens))
},
)
mlx.Keep(visualTokens)
mlx.Eval(visualTokens)
fmt.Printf("[AR] Done generating visual tokens\n")
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
vtShape := visualTokens.Shape()
totalGenerated := vtShape[1]
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
// Extract only the TARGET grid tokens (skip the prev grid tokens)
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
// large_image_start_offset = prev_grid_size
var targetGridVisualTokens *mlx.Array
if totalGenerated >= prevGridTokens+targetGridTokens {
// Full generation completed - extract target grid
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, prevGridTokens + targetGridTokens})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
} else if totalGenerated > prevGridTokens {
// Partial target grid - take what we have
actualTargetTokens := totalGenerated - prevGridTokens
targetGridVisualTokens = mlx.Slice(visualTokens,
[]int32{0, prevGridTokens},
[]int32{1, totalGenerated})
mlx.Keep(targetGridVisualTokens)
mlx.Eval(targetGridVisualTokens)
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
actualTargetTokens, targetGridTokens)
} else {
// Not enough tokens - EOS came too early
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
totalGenerated, prevGridTokens)
}
// === PHASE 3: Diffusion Decoding ===
// Setup scheduler with dynamic shift based on image size
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
// Initialize noise latents [B, C, H, W]
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
// target_grid tokens -> 2x upsample -> patch_count
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
mlx.Keep(priorEmbed)
mlx.Eval(priorEmbed)
// Prepare text conditioning (project T5 embeddings)
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
mlx.Keep(textCond)
mlx.Eval(textCond)
// === CFG Setup ===
// For classifier-free guidance, we need unconditional (negative) text embeddings
// GLM-Image uses empty string "" for negative prompt
doCFG := cfg.GuidanceScale > 1.0
var negativeTextCond *mlx.Array
if doCFG {
// Encode empty string for negative prompt
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
mlx.Keep(negativeTextEmbed)
mlx.Eval(negativeTextEmbed)
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
mlx.Keep(negativeTextCond)
mlx.Eval(negativeTextCond)
negativeTextEmbed.Free()
}
// Prepare conditioning inputs
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
targetSize = mlx.ToBFloat16(targetSize)
cropCoords = mlx.ToBFloat16(cropCoords)
mlx.Keep(targetSize)
mlx.Keep(cropCoords)
mlx.Eval(targetSize, cropCoords)
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
// Denoising loop
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
progress("diffusion", 0, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
latents.Free()
return nil, ctx.Err()
default:
}
}
// Get timestep value for the transformer
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
timestepVal := scheduler.Timesteps[i] - 1
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
patches := PatchifyLatents(latents, tcfg.PatchSize)
// Transformer forward with MMDiT architecture
// Conditional pass (with text + prior embeddings)
outputCond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed,
textCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
false, // priorTokenDrop = false for conditional
)
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
var noisePred *mlx.Array
if doCFG {
// Unconditional pass (empty text, dropped prior embeddings)
outputUncond := m.Transformer.ForwardWithPriorDrop(
patches,
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
negativeTextCond,
timestep,
targetSize,
cropCoords,
pH,
pW,
true, // priorTokenDrop = true for unconditional
)
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
diff := mlx.Sub(noisePredCond, noisePredUncond)
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
noisePred = mlx.Add(noisePredUncond, scaled)
} else {
noisePred = noisePredCond
}
// Scheduler step
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
progress("diffusion", i+1, cfg.Steps)
}
// Cleanup intermediate arrays
textEmbed.Free()
visualTokens.Free()
// visualTokensUpsampled points to visualTokens, don't double-free
priorEmbed.Free()
textCond.Free()
if negativeTextCond != nil {
negativeTextCond.Free()
}
targetSize.Free()
cropCoords.Free()
// === PHASE 4: VAE Decode ===
progress("vae_decode", 0, 1)
decoded := m.VAEDecoder.Decode(latents)
mlx.Eval(decoded)
latents.Free()
progress("vae_decode", 1, 1)
return decoded, nil
}
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
// Each token at (i, j) becomes scale*scale tokens in the output
mlx.Eval(tokens)
tokenData := tokens.DataInt32()
numTokens := int32(len(tokenData))
expectedTokens := prevH * prevW
// Warn if we got fewer tokens than expected (early EOS)
if numTokens < expectedTokens {
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
numTokens, expectedTokens)
}
targetH := prevH * scale
targetW := prevW * scale
upsampled := make([]int32, targetH*targetW)
for i := int32(0); i < prevH; i++ {
for j := int32(0); j < prevW; j++ {
srcIdx := i*prevW + j
// Handle early EOS: use 0 (padding) for missing tokens
var val int32
if srcIdx < numTokens {
val = tokenData[srcIdx]
} else {
val = 0 // Padding token
}
// Place in scale*scale positions
dstI := i * scale
dstJ := j * scale
for di := int32(0); di < scale; di++ {
for dj := int32(0); dj < scale; dj++ {
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
}
}
}
}
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
// Transpose: -> [B, pH, pW, C, p, p]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// Flatten: -> [B, pH*pW, C*p*p]
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
pH := H / patchSize
pW := W / patchSize
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
// Transpose: -> [B, C, pH, p, pW, p]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// Reshape: -> [B, C, H, W]
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}
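// Shape check (illustrative, assuming 16 latent channels and patchSize 2):
// for a 1024x1024 image the latents are [1, 16, 128, 128], PatchifyLatents
// yields [1, 64*64, 16*2*2] = [1, 4096, 64], and UnpatchifyLatents maps that
// back to [1, 16, 128, 128].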
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
cfg := DefaultSchedulerConfig()
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
return m*cfg.MaxShift + cfg.BaseShift
}
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
shape := tokens.Shape()
B := shape[0]
// Reshape to [B, 1, H, W] for interpolation
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
// Convert to float for interpolation
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
// 2x nearest neighbor upsample
// [B, 1, H, W] -> [B, 1, H*2, W*2]
upsampled := nearestUpsample2x(tokensFloat)
// Convert back to int and reshape to [B, H*2*W*2]
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// Repeat each element 2x2
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
x = mlx.Reshape(x, B, C, H, 1, W, 1)
// Tile to repeat each pixel 2x2
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
// Reshape to final size
return mlx.Reshape(x, B, C, H*2, W*2)
}

View File

@@ -1,358 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/ollama/ollama/x/imagegen"
)
// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
Vocab map[string]int32 // token string -> token ID
VocabReverse map[int32]string // token ID -> token string
SpecialTokens map[string]int32 // special token strings -> IDs
// Special token IDs
SopTokenID int32 // <sop> = grid_bos_token (167845)
EopTokenID int32 // <eop> = grid_eos_token (167846)
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
PadTokenID int32
// Sorted vocab keys by length (longest first) for greedy matching
sortedTokens []string
}
// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
Model struct {
Vocab map[string]int32 `json:"vocab"`
} `json:"model"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory in manifest
data, err := manifest.ReadConfig("processor/tokenizer.json")
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
// Read tokenizer.json from processor directory
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
data, err := os.ReadFile(tokenizerPath)
if err != nil {
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
}
var tj tokenizerJSON
if err := json.Unmarshal(data, &tj); err != nil {
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
}
tok := &GLMTokenizer{
Vocab: make(map[string]int32),
VocabReverse: make(map[int32]string),
SpecialTokens: make(map[string]int32),
}
// Load vocab from model section
for token, id := range tj.Model.Vocab {
tok.Vocab[token] = id
tok.VocabReverse[id] = token
}
// Load added tokens (special tokens including dit_tokens)
for _, at := range tj.AddedTokens {
tok.Vocab[at.Content] = at.ID
tok.VocabReverse[at.ID] = at.Content
if at.Special {
tok.SpecialTokens[at.Content] = at.ID
}
}
// Set special token IDs
tok.SopTokenID = 167845 // <sop>
tok.EopTokenID = 167846 // <eop>
tok.BosTokenID = 16384 // <|dit_token_16384|>
tok.EosTokenID = 16385 // <|dit_token_16385|>
tok.PadTokenID = 16385 // Same as EOS
// Build sorted token list for greedy matching (longest first)
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
for token := range tok.Vocab {
tok.sortedTokens = append(tok.sortedTokens, token)
}
sort.Slice(tok.sortedTokens, func(i, j int) bool {
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
})
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
return tok, nil
}
// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
if text == "" {
return []int32{}
}
var tokens []int32
// First, collect the special tokens that appear in the text so the main loop
// below can match them directly before falling back to greedy vocab matching
specialReplacements := make(map[string]int32)
for special, id := range t.SpecialTokens {
if strings.Contains(text, special) {
specialReplacements[special] = id
}
}
// Process text character by character with special token handling
i := 0
isFirstToken := true
for i < len(text) {
// Check for special tokens first
foundSpecial := false
for special, id := range specialReplacements {
if strings.HasPrefix(text[i:], special) {
tokens = append(tokens, id)
i += len(special)
isFirstToken = false
foundSpecial = true
break
}
}
if foundSpecial {
continue
}
// Handle regular text with GPT-2 style space prefix
// "Ġ" (U+0120) represents a space before a token
remaining := text[i:]
// Try to find the longest matching token
matched := false
for _, token := range t.sortedTokens {
// Skip special tokens in regular matching
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
continue
}
// Check if this token matches
tokenText := token
// Handle the Ġ prefix (represents space)
if strings.HasPrefix(token, "Ġ") {
// This token expects a leading space
if i > 0 || !isFirstToken {
// Check if remaining starts with space + token content
tokenContent := token[len("Ġ"):]
if strings.HasPrefix(remaining, " "+tokenContent) {
tokens = append(tokens, t.Vocab[token])
i += 1 + len(tokenContent) // space + content
isFirstToken = false
matched = true
break
}
}
} else {
// Regular token without space prefix
if strings.HasPrefix(remaining, tokenText) {
tokens = append(tokens, t.Vocab[token])
i += len(tokenText)
isFirstToken = false
matched = true
break
}
}
}
if !matched {
// No token found - skip this character (or use UNK)
// For now, just skip unknown characters
i++
}
}
return tokens
}
// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
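//
// Example (illustrative): a 1024x1024 request gives tokenH = tokenW = 32 and
// prevTokenH = prevTokenW = 16, so the suffix appended after the prompt tokens
// is <sop> "32" "Ġ32" <eop> <sop> "16" "Ġ16" <eop> <|dit_token_16384|>.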
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
// Calculate grid dimensions
factor := int32(32)
height := (targetHeight / factor) * factor
width := (targetWidth / factor) * factor
tokenH := height / factor
tokenW := width / factor
// Calculate previous grid dimensions
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(sqrt(ratio) * 16)
prevTokenW := int32(sqrt(1.0/ratio) * 16)
// Encode the prompt text
promptTokens := t.Encode(prompt)
// Build the full sequence:
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
var tokens []int32
tokens = append(tokens, promptTokens...)
// First grid: <sop> H W <eop>
// First number has no space prefix, second number has space prefix (Ġ)
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(tokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
tokens = append(tokens, t.EopTokenID)
// Second grid: <sop> prevH prevW <eop>
tokens = append(tokens, t.SopTokenID)
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
tokens = append(tokens, t.EopTokenID)
// BOS token (start of image generation)
tokens = append(tokens, t.BosTokenID)
return tokens
}
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up the whole number as a single token
if id, ok := t.Vocab[s]; ok {
return []int32{id}
}
// Fallback: encode digit by digit
var tokens []int32
for _, c := range s {
if id, ok := t.Vocab[string(c)]; ok {
tokens = append(tokens, id)
}
}
return tokens
}
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
s := fmt.Sprintf("%d", n)
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
spaceToken := "Ġ" + s
if id, ok := t.Vocab[spaceToken]; ok {
return []int32{id}
}
// Fallback: bare space Ġ + number tokens
var tokens []int32
if spaceID, ok := t.Vocab["Ġ"]; ok {
tokens = append(tokens, spaceID)
}
tokens = append(tokens, t.encodeNumber(n)...)
return tokens
}
// sqrt is a helper for float64 sqrt
func sqrt(x float64) float64 {
if x <= 0 {
return 0
}
// Newton's method
z := x
for i := 0; i < 10; i++ {
z = z - (z*z-x)/(2*z)
}
return z
}
// Decode converts token IDs back to a string
func (t *GLMTokenizer) Decode(tokens []int32) string {
var sb strings.Builder
for _, id := range tokens {
if token, ok := t.VocabReverse[id]; ok {
// Handle Ġ prefix (convert back to space)
if strings.HasPrefix(token, "Ġ") {
sb.WriteString(" ")
sb.WriteString(token[len("Ġ"):])
} else {
sb.WriteString(token)
}
}
}
return sb.String()
}

View File

@@ -1,159 +0,0 @@
//go:build mlx
package glm_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.25
MaxShift float32 `json:"max_shift"` // 0.75
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "linear"
}
// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.25,
MaxShift: 0.75,
BaseImageSeqLen: 256,
MaxImageSeqLen: 4096,
UseDynamicShifting: true,
TimeShiftType: "linear",
}
}
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
Sigmas []float32 // Shifted sigmas for Euler step calculation
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{Config: cfg}
}
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
s.NumSteps = numSteps
// Calculate shift (mu) based on image sequence length
mu := s.calculateShift(imgSeqLen)
// Create timesteps: linspace from sigma_max_t to sigma_min_t
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
// Then apply time shift and append terminal sigma=0
s.Timesteps = make([]float32, numSteps)
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
for i := 0; i < numSteps; i++ {
// linspace from 1000 down to 1 (sigma_min * num_train_timesteps, with sigma_min = 0.001)
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
s.Timesteps[i] = tRaw
// Convert to sigma [0, 1]
sigma := tRaw / numTrainTimesteps
// Apply time shift if enabled
if s.Config.UseDynamicShifting && mu > 0 {
sigma = s.applyShift(mu, sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal sigma = 0 (the final clean image)
s.Sigmas[numSteps] = 0
}
// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
// m = (image_seq_len / base_seq_len) ** 0.5
// mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
cfg := s.Config
if !cfg.UseDynamicShifting {
return 0
}
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
mu := m*cfg.MaxShift + cfg.BaseShift
return mu
}
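// Worked example: a 1024x1024 image with patch size 2 has imgSeqLen = 64*64 =
// 4096, so m = sqrt(4096/256) = 4 and mu = 4*0.75 + 0.25 = 3.25.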
// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if t >= 1 {
return 1
}
// sigma=1.0 for both shift types
sigma := float32(1.0)
if s.Config.TimeShiftType == "linear" {
// Linear: mu / (mu + (1/t - 1)^sigma)
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}
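// For instance, with mu = 3.25 and t = 0.5 the linear shift gives
// 3.25 / (3.25 + 1) ≈ 0.76, pulling mid-range sigmas toward the noisy end of
// the schedule.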
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
sigma := s.Sigmas[stepIdx]
sigmaNext := s.Sigmas[stepIdx+1]
// Euler step: x_{t-dt} = x_t + dt * v_t
dt := sigmaNext - sigma // Negative (going from noise to clean)
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
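// For example, stepping from sigma = 0.8 to sigmaNext = 0.6 gives dt = -0.2,
// so the update is x <- x - 0.2*v, moving the sample toward the clean image.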
// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
// Use sigmas (shifted) for the interpolation
sigma := s.Sigmas[timestepIdx]
oneMinusSigma := 1.0 - sigma
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
scaledNoise := mlx.MulScalar(noise, sigma)
return mlx.Add(scaledClean, scaledNoise)
}
// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
return s.Timesteps
}

View File

@@ -1,497 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// T5Config holds T5 encoder configuration
type T5Config struct {
DModel int32 `json:"d_model"` // 1472
DFF int32 `json:"d_ff"` // 3584
DKV int32 `json:"d_kv"` // 64
NumHeads int32 `json:"num_heads"` // 6
NumLayers int32 `json:"num_layers"` // 12
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
// Relative position bias
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}
// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
Config *T5Config
// Embedding (shared for ByT5)
SharedEmbed *nn.Embedding `weight:"shared"`
// Encoder layers
Layers []*T5Block `weight:"encoder.block"`
// Final layer norm
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
// Relative position bias (from first layer, shared across all)
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}
// T5Block is a single T5 encoder block
type T5Block struct {
// Self attention
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
// FFN
Layer1 *T5LayerFF `weight:"layer.1"`
}
// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
SelfAttention *T5Attention `weight:"SelfAttention"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5Attention implements T5's relative attention
type T5Attention struct {
Q *mlx.Array `weight:"q.weight"` // No bias in T5
K *mlx.Array `weight:"k.weight"`
V *mlx.Array `weight:"v.weight"`
O *mlx.Array `weight:"o.weight"`
NHeads int32
DKV int32
Scale float32
}
// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
LayerNorm *T5LayerNorm `weight:"layer_norm"`
}
// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
Wo *mlx.Array `weight:"wo.weight"` // down projection
}
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
Weight *mlx.Array `weight:"weight"`
Eps float32
}
// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading T5 text encoder... ")
// Load config
var cfg T5Config
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Pre-allocate layers
m.Layers = make([]*T5Block, cfg.NumLayers)
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
func (m *T5TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.LayerNormEps
for _, block := range m.Layers {
attn := block.Layer0.SelfAttention
attn.NHeads = cfg.NumHeads
attn.DKV = cfg.DKV
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
}
}
// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
cfg := m.Config
// Get embeddings
h := m.SharedEmbed.Forward(tokens)
// Compute relative position bias once
seqLen := tokens.Shape()[1]
posBias := m.computeRelativePositionBias(seqLen)
// Forward through layers
for _, block := range m.Layers {
h = block.Forward(h, posBias, cfg.LayerNormEps)
}
// Final norm
h = m.FinalNorm.Forward(h)
return h
}
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
func extractGlyphTexts(prompt string) []string {
var glyphTexts []string
// Extract text in single quotes: 'text'
re1 := regexp.MustCompile(`'([^']*)'`)
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Unicode curly double quotes: “text”
re2 := regexp.MustCompile(`“([^“”]*)”`)
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in ASCII double quotes: "text"
re3 := regexp.MustCompile(`"([^"]*)"`)
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
// Extract text in Japanese quotes: 「text」
re4 := regexp.MustCompile(`「([^「」]*)」`)
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
if len(match) > 1 {
glyphTexts = append(glyphTexts, match[1])
}
}
return glyphTexts
}
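// Example: extractGlyphTexts(`a poster that says 'welcome' and "OPEN"`)
// returns []string{"welcome", "OPEN"}: single-quoted matches are collected
// first, then curly-quoted, then ASCII double-quoted ones.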
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
// Extract glyph texts from prompt (text in quotes)
glyphTexts := extractGlyphTexts(prompt)
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
if len(glyphTexts) == 0 {
glyphTexts = []string{""}
}
// Encode each glyph text and collect token sequences
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
var allTokenSeqs [][]int32
for _, glyphText := range glyphTexts {
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
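// e.g. "Hi" → bytes [72, 105] → tokens [75, 108]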
tokens := tok.Encode(glyphText)
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
tokens = append(tokens, tok.EOSTokenID)
allTokenSeqs = append(allTokenSeqs, tokens)
}
// Process each glyph text through the encoder
var allEmbeddings []*mlx.Array
for _, tokens := range allTokenSeqs {
tokenLen := len(tokens)
if tokenLen == 0 {
continue
}
// Create token array [1, L]
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
// Forward through encoder
output := m.Forward(tokensArr)
mlx.Eval(output)
allEmbeddings = append(allEmbeddings, output)
}
// Concatenate all glyph embeddings along sequence dimension
var output *mlx.Array
if len(allEmbeddings) == 0 {
// Fallback: return single zero embedding
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
} else if len(allEmbeddings) == 1 {
output = allEmbeddings[0]
} else {
output = mlx.Concatenate(allEmbeddings, 1)
}
mlx.Eval(output)
return output
}
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
cfg := m.Config
// Create relative position matrix
// For each (query_pos, key_pos) pair, compute bucketed relative position
numBuckets := cfg.RelativeAttentionNumBuckets
maxDistance := cfg.RelativeAttentionMaxDistance
// Create position indices
contextPos := make([]int32, seqLen*seqLen)
memoryPos := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen; i++ {
for j := int32(0); j < seqLen; j++ {
contextPos[i*seqLen+j] = i
memoryPos[i*seqLen+j] = j
}
}
// Compute relative positions and bucket them
buckets := make([]int32, seqLen*seqLen)
for i := int32(0); i < seqLen*seqLen; i++ {
relPos := memoryPos[i] - contextPos[i]
buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
}
// Create bucket indices array
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
// Transpose to [numHeads, seqLen, seqLen]
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
return bias
}
// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
var bucket int32 = 0
var n int32 = -relativePosition
if bidirectional {
numBuckets /= 2
if n < 0 {
bucket += numBuckets
n = -n
}
} else {
if n < 0 {
n = 0
}
}
// Half buckets are for exact positions, half are for log-spaced
maxExact := numBuckets / 2
if n < maxExact {
bucket += n
} else {
// Log-spaced buckets
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
if bucket > numBuckets-1 {
bucket = numBuckets - 1
}
}
return bucket
}
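// Worked example (unidirectional, numBuckets=32, maxDistance=128):
// relativePosition = -50 → n = 50 and maxExact = 16; since n >= maxExact,
// logVal = log(50/16)/log(128/16) ≈ 0.548, so bucket = 16 + int32(0.548*16) = 24.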
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Self attention with residual
h := b.Layer0.Forward(x, posBias, eps)
// FFN with residual
h = b.Layer1.Forward(h, eps)
return h
}
// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// Attention
attnOut := l.SelfAttention.Forward(normed, posBias)
// Residual
return mlx.Add(x, attnOut)
}
// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Q, K, V projections (no bias)
// Weights are [out_features, in_features], so we use matmul with transpose
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
// Reshape to [B, L, nheads, d_kv]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
// Transpose to [B, nheads, L, d_kv]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Attention scores with relative position bias
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
// (no 1/sqrt(d_k) scale factor like in standard transformers)
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
scores = mlx.Add(scores, posBias)
// Softmax
attnWeights := mlx.Softmax(scores, -1)
// Attend to values
out := mlx.Matmul(attnWeights, v)
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, D]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
_ = D // Silence unused warning
return out
}
// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
// Pre-norm
normed := l.LayerNorm.Forward(x)
// FFN
ffOut := l.DenseReluDense.Forward(normed)
// Residual
return mlx.Add(x, ffOut)
}
// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
coeff := float32(0.044715)
x3 := mlx.Mul(mlx.Mul(x, x), x)
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}
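// Sanity check: geluNew(0) = 0 and geluNew(1) ≈ 0.8412, matching PyTorch's
// tanh-approximate GELU (approximate="tanh").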
// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
gate = geluNew(gate)
// Up projection
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
// Gated output
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}
// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
variance := mlx.Mean(mlx.Square(x), -1, true)
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
return mlx.Mul(x, ln.Weight)
}


File diff suppressed because it is too large


@@ -1,477 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"` // 3
OutChannels int32 `json:"out_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 16
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
LayersPerBlock int32 `json:"layers_per_block"` // 3
NormNumGroups int32 `json:"norm_num_groups"` // 32
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
ShiftFactor *float32 `json:"shift_factor"` // null
LatentsMean []float32 `json:"latents_mean"` // [16 values]
LatentsStd []float32 `json:"latents_std"` // [16 values]
}
// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
Config *VAEConfig
// Decoder components
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
}
// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
Stride int32
Padding int32
}
// GroupNorm is group normalization
type GroupNorm struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
}
// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
Resnets []*VAEResnetBlock `weight:"resnets"`
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
}
// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
Norm1 *GroupNorm `weight:"norm1"`
Conv1 *VAEConv2d `weight:"conv1"`
Norm2 *GroupNorm `weight:"norm2"`
Conv2 *VAEConv2d `weight:"conv2"`
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}
// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
Conv *VAEConv2d `weight:"conv"`
}
// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
// The VAE decoder has layers_per_block+1 resnets per up_block (to match the encoder),
// and all up_blocks except the last have an upsampler
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
// All blocks except the last have upsamplers
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
fmt.Print(" Loading VAE decoder... ")
// Load config
var cfg VAEConfig
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
m.Config = &cfg
// Initialize structure based on config
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
m.MidBlock = &VAEMidBlock{
Resnets: make([]*VAEResnetBlock, 2),
}
// Pre-allocate UpBlocks with their resnets and upsamplers
for i := 0; i < numBlocks; i++ {
numResnets := cfg.LayersPerBlock + 1
m.UpBlocks[i] = &VAEUpBlock{
Resnets: make([]*VAEResnetBlock, numResnets),
}
if i < numBlocks-1 {
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
}
}
// Load weights from safetensors files
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
// Initialize GroupNorm parameters
m.initGroupNorms()
fmt.Println("✓")
return nil
}
func (m *VAEDecoder) initGroupNorms() {
cfg := m.Config
numGroups := cfg.NormNumGroups
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
if m.ConvNormOut != nil {
m.ConvNormOut.NumGroups = numGroups
m.ConvNormOut.Eps = eps
}
if m.MidBlock != nil {
for _, resnet := range m.MidBlock.Resnets {
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
for _, upBlock := range m.UpBlocks {
if upBlock == nil {
continue
}
for _, resnet := range upBlock.Resnets {
if resnet == nil {
continue
}
if resnet.Norm1 != nil {
resnet.Norm1.NumGroups = numGroups
resnet.Norm1.Eps = eps
}
if resnet.Norm2 != nil {
resnet.Norm2.NumGroups = numGroups
resnet.Norm2.Eps = eps
}
}
}
}
// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Apply latent denormalization if mean/std are provided
// This matches diffusers GLM-Image: latents = latents * std + mean
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
latents = m.denormalizeLatents(latents)
}
// Convert from NCHW to NHWC for processing
// [B, C, H, W] -> [B, H, W, C]
x := mlx.Transpose(latents, 0, 2, 3, 1)
// Initial convolution
x = m.ConvIn.Forward(x)
// Mid block
x = m.MidBlock.Forward(x)
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
for i := 0; i < len(m.UpBlocks); i++ {
if m.UpBlocks[i] != nil {
x = m.UpBlocks[i].Forward(x)
}
}
// Final normalization and convolution
x = m.ConvNormOut.Forward(x)
x = mlx.SiLU(x)
x = m.ConvOut.Forward(x)
// Convert back to NCHW
// [B, H, W, C] -> [B, C, H, W]
x = mlx.Transpose(x, 0, 3, 1, 2)
// Clamp to valid range and convert to [0, 1]
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
x = mlx.AddScalar(x, 1.0)
x = mlx.DivScalar(x, 2.0)
return x
}
// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
cfg := m.Config
// Create mean and std arrays [1, C, 1, 1] for broadcasting
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
// Denormalize: latents * std + mean
latents = mlx.Mul(latents, std)
latents = mlx.Add(latents, mean)
return latents
}
// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C_in] (NHWC)
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
// So we need to transpose from OIHW to OHWI
stride := c.Stride
if stride == 0 {
stride = 1
}
padding := c.Padding
if padding == 0 {
// Default to same padding for 3x3 kernels
wShape := c.Weight.Shape()
if len(wShape) >= 3 && wShape[2] == 3 {
padding = 1
}
}
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
out := mlx.Conv2d(x, weight, stride, padding)
if c.Bias != nil {
// Bias: [C_out] -> [1, 1, 1, C_out]
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
out = mlx.Add(out, bias)
}
return out
}
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
numGroups := gn.NumGroups
if numGroups == 0 {
numGroups = 32
}
groupSize := C / numGroups
// Reshape to [B, H, W, groups, groupSize]
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
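// Example: with C=1024 and numGroups=32, groupSize=32, so each (batch, group)
// pair is normalized with mean/variance computed over H*W*32 elements.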
// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range mb.Resnets {
x = resnet.Forward(x)
}
return x
}
// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
// Apply resnets
for _, resnet := range ub.Resnets {
if resnet != nil {
x = resnet.Forward(x)
}
}
// Apply upsamplers
for _, upsampler := range ub.Upsamplers {
if upsampler != nil {
x = upsampler.Forward(x)
}
}
return x
}
// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
// First norm + activation + conv
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
// Second norm + activation + conv
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
// Shortcut for channel mismatch
if rb.ConvShortcut != nil {
residual = rb.ConvShortcut.Forward(residual)
}
return mlx.Add(h, residual)
}
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C]
// 2x nearest neighbor upsample
x = upsample2x(x)
// Conv
if us.Conv != nil {
x = us.Conv.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
hIndices := make([]int32, H*2)
for i := int32(0); i < H; i++ {
hIndices[i*2] = i
hIndices[i*2+1] = i
}
wIndices := make([]int32, W*2)
for i := int32(0); i < W; i++ {
wIndices[i*2] = i
wIndices[i*2+1] = i
}
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
// Take along height axis
x = mlx.Reshape(x, B*H, W, C)
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
x = mlx.Reshape(x, B, H, W*2, C)
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
x = mlx.Reshape(x, B*(W*2), H, C)
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
x = mlx.Reshape(x, B, W*2, H*2, C)
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
return x
}
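// Example: a [1, 2, 2, C] input
//   [[a, b],
//    [c, d]]
// becomes the [1, 4, 4, C] output
//   [[a, a, b, b],
//    [a, a, b, b],
//    [c, c, d, d],
//    [c, c, d, d]]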


@@ -1,982 +0,0 @@
//go:build mlx
package glm_image
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 4096
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
IntermediateSize int32 `json:"intermediate_size"` // 13696
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
VocabSize int32 `json:"vocab_size"` // 168064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
// RoPE config
RopeTheta float32 `json:"rope_theta"` // 10000
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
// Visual token config
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
ImageTokenID int32 `json:"image_token_id"` // 167855
// Computed
HeadDim int32
}
// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
Config *VisionLanguageConfig
// Embedding
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
// Transformer layers
Layers []*GLMBlock `weight:"model.language_model.layers"`
// Final norm
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
// LM Head
LMHead *mlx.Array `weight:"lm_head.weight"`
}
// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
// Pre-attention norm (GLM uses post-LN variant)
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
// Attention
SelfAttn *GLMAttention `weight:"self_attn"`
// MLP (fused gate_up)
MLP *GLMMLP `weight:"mlp"`
}
// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
QProj *mlx.Array `weight:"q_proj.weight"`
KProj *mlx.Array `weight:"k_proj.weight"`
VProj *mlx.Array `weight:"v_proj.weight"`
OProj *mlx.Array `weight:"o_proj.weight"`
// QKV have biases in GLM
QBias *mlx.Array `weight:"q_proj.bias"`
KBias *mlx.Array `weight:"k_proj.bias"`
VBias *mlx.Array `weight:"v_proj.bias"`
// Computed
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
PartialRotary float32 // Only rotate this fraction of head_dim
RopeTheta float32
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
// ARCache holds KV caches for all layers using the shared cache implementation
type ARCache struct {
Layers []cache.Cache
}
// NewARCache creates a new cache for the given number of layers
func NewARCache(numLayers int32) *ARCache {
layers := make([]cache.Cache, numLayers)
for i := range layers {
layers[i] = cache.NewKVCache()
}
return &ARCache{Layers: layers}
}
// Free releases all cached tensors
func (c *ARCache) Free() {
for _, layer := range c.Layers {
for _, arr := range layer.State() {
if arr != nil {
arr.Free()
}
}
}
}
// GLMMLP implements fused gate_up SwiGLU MLP
type GLMMLP struct {
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
DownProj *mlx.Array `weight:"down_proj.weight"`
}
// Load loads the vision-language encoder from manifest
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
return fmt.Errorf("config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
fmt.Print(" Loading vision-language encoder... ")
// Load config
var rawCfg struct {
TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
VisionVocabSize int32 `json:"vision_vocab_size"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPESection []int32 `json:"mrope_section"`
} `json:"rope_parameters"`
} `json:"text_config"`
ImageStartTokenID int32 `json:"image_start_token_id"`
ImageEndTokenID int32 `json:"image_end_token_id"`
ImageTokenID int32 `json:"image_token_id"`
}
configPath := filepath.Join(path, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
if err := json.Unmarshal(data, &rawCfg); err != nil {
return fmt.Errorf("parse config: %w", err)
}
cfg := &VisionLanguageConfig{
HiddenSize: rawCfg.TextConfig.HiddenSize,
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
VocabSize: rawCfg.TextConfig.VocabSize,
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
ImageStartTokenID: rawCfg.ImageStartTokenID,
ImageEndTokenID: rawCfg.ImageEndTokenID,
ImageTokenID: rawCfg.ImageTokenID,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
m.Config = cfg
// Pre-allocate layers
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
// Load weights
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
return nil
}
func (m *VisionLanguageEncoder) initComputedFields() {
cfg := m.Config
for _, block := range m.Layers {
block.SelfAttn.NHeads = cfg.NumAttentionHeads
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
block.SelfAttn.HeadDim = cfg.HeadDim
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
block.SelfAttn.RopeTheta = cfg.RopeTheta
block.SelfAttn.MRoPESection = cfg.MRoPESection
// Set norm eps
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
}
m.FinalNorm.Eps = cfg.RMSNormEps
}
// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
prompt string,
tok *GLMTokenizer,
maxTokens int32,
temperature float32,
topP float32,
seed int64,
targetHeight, targetWidth int32,
progressFn func(int),
) *mlx.Array {
cfg := m.Config
// Encode prompt with grid tokens using GLM tokenizer
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
// Calculate grid dimensions for MRoPE position IDs
factor := int32(32)
tokenH := targetHeight / factor
tokenW := targetWidth / factor
ratio := float64(tokenH) / float64(tokenW)
prevTokenH := int32(math.Sqrt(ratio) * 16)
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
prevGridSize := prevTokenH * prevTokenW
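// Example: for a 1024x1024 target, tokenH = tokenW = 32 and ratio = 1,
// so prevTokenH = prevTokenW = 16 and prevGridSize = 256.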
// Create KV cache for all layers
cache := NewARCache(cfg.NumHiddenLayers)
defer cache.Free()
// ===== PREFILL PHASE =====
// Process entire prompt at once, populate cache
promptLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
mlx.Eval(h)
// Compute position IDs for prefill (text tokens use same position for all dims)
prefillPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
prefillPositions[dim] = make([]int32, promptLen)
for i := int32(0); i < promptLen; i++ {
prefillPositions[dim][i] = i
}
}
// Forward through layers (prefill)
for i, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
if i > 0 {
oldH.Free()
}
}
// Eval h and cache arrays together so cache is materialized
evalArgs := []*mlx.Array{h}
for _, lc := range cache.Layers {
evalArgs = append(evalArgs, lc.State()...)
}
mlx.Eval(evalArgs...)
// Final norm and get logits for last position
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
h.Free()
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
lastH.Free()
// Sample first token
var sampleCounter int64 = 0
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
// AR generation loop with caching
// Visual tokens are stored as VQ codebook indices [0, 16383]
// The LM head outputs indices [0, 16511] where:
// - [0, 16383] are VQ codes
// - 16384 is BOS
// - 16385 is EOS
visualTokens := make([]int32, 0, maxTokens)
posOffset := promptLen
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
// Preallocate slice for old cache state to reuse
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for i := int32(0); i < maxTokens; i++ {
if progressFn != nil {
progressFn(int(i))
}
// Check for end token (EOS = 16385)
if nextToken == cfg.ImageEndTokenID {
break
}
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
if nextToken == cfg.ImageStartTokenID {
// BOS token - skip storing but continue generation
} else if nextToken < cfg.ImageStartTokenID {
// This is an actual VQ code [0, 16383] - store it
visualTokens = append(visualTokens, nextToken)
}
// Tokens >= 16386 are other special tokens, skip them
// ===== DECODE PHASE =====
// Save old cache state before forward (to free after eval)
oldCacheState = oldCacheState[:0]
for _, lc := range cache.Layers {
oldCacheState = append(oldCacheState, lc.State()...)
}
// Only process the new token, use cached K,V
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
h := m.EmbedTokens.Forward(tokenArr)
tokenArr.Free()
// Compute MRoPE position IDs for this visual token
// Visual tokens are arranged in two grids: prev grid then target grid
// Position dimensions: [temporal, height, width]
decodePositions := computeVisualTokenPositions(
visualTokenIdx, posOffset, promptLen,
prevTokenH, prevTokenW, prevGridSize,
tokenH, tokenW,
)
// Forward through layers (decode with cache)
for j, layer := range m.Layers {
oldH := h
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
if j > 0 { // Don't free the embedding on first layer
oldH.Free()
}
}
// Eval h and new cache state
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
for _, lc := range cache.Layers {
newCacheState = append(newCacheState, lc.State()...)
}
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
// Free old cache state (now that new state is evaluated)
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
// Final norm
preNormH := h
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
preNormH.Free()
// Get logits (h is already [1, 1, hidden_size])
h = mlx.Reshape(h, 1, cfg.HiddenSize)
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
h.Free()
// Sample next token
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
logits.Free()
posOffset++
visualTokenIdx++
// Periodically clear cache to release intermediate memory
if i%256 == 0 {
mlx.ClearCache()
}
}
if len(visualTokens) == 0 {
// Return at least one token to avoid empty tensor issues
visualTokens = append(visualTokens, 0)
}
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
visualIdx int32, absPos int32, promptLen int32,
prevH, prevW, prevSize int32,
targetH, targetW int32,
) [][]int32 {
positions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
positions[dim] = make([]int32, 1)
}
// First grid (prev grid) starts at decode_pos = promptLen
prevGridDecodePos := promptLen
// Second grid (target grid) starts after first grid
// next_pos = prev_decode_pos + max(prevH, prevW)
maxPrev := prevH
if prevW > maxPrev {
maxPrev = prevW
}
targetGridDecodePos := prevGridDecodePos + maxPrev
// Compute position IDs based on which grid the token is in
if visualIdx < prevSize {
// Token is in the prev grid (prev_token_h × prev_token_w)
row := visualIdx / prevW
col := visualIdx % prevW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = prevGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = prevGridDecodePos + row
positions[2][0] = prevGridDecodePos + col
} else {
// Token is in the target grid (token_h × token_w)
targetIdx := visualIdx - prevSize
row := targetIdx / targetW
col := targetIdx % targetW
// temporal is CONSTANT for all tokens in this grid
positions[0][0] = targetGridDecodePos
// height and width are relative to grid's decode_pos
positions[1][0] = targetGridDecodePos + row
positions[2][0] = targetGridDecodePos + col
}
_ = targetH // Used for documentation clarity
_ = absPos // No longer used - kept for API compatibility
return positions
}
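// Example: promptLen=100, prevH=prevW=16 (prevSize=256), targetH=targetW=32.
// visualIdx=17 (row 1, col 1 of the prev grid) → positions (100, 101, 101).
// visualIdx=256 (first token of the target grid, which starts at decode_pos
// 100+16=116) → positions (116, 116, 116).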
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
// Output index directly corresponds to vocab ID [0, 16511]
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
visualLogits := logits
// Apply temperature
if temperature != 1.0 && temperature > 0 {
visualLogits = mlx.DivScalar(visualLogits, temperature)
}
// Apply softmax to get probabilities
probs := mlx.Softmax(visualLogits, -1)
mlx.Eval(probs)
// Get the sampled index using top-p sampling
// This directly gives us the vocab ID in [0, 16511]
// Special tokens: 16384 = BOS, 16385 = EOS
// Use seed + counter for reproducible but different random values
effectiveSeed := seed + *sampleCounter
*sampleCounter++
return sampleTopP(probs, topP, effectiveSeed)
}
// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
// Negate probs for descending sort (Argsort only does ascending)
negProbs := mlx.MulScalar(probs, -1)
sortedIndices := mlx.Argsort(negProbs, -1)
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
cumProbs := mlx.Cumsum(sortedProbs, -1)
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
// Find cutoff index where cumulative probability exceeds topP
probsData := sortedProbs.Data()
cumProbsData := cumProbs.Data()
indicesData := sortedIndices.DataInt32()
// Calculate cutoff and renormalize
var cutoffIdx int
var totalProb float32
for i, cp := range cumProbsData {
totalProb += probsData[i]
if cp >= topP {
cutoffIdx = i + 1 // Include this token
break
}
}
if cutoffIdx == 0 {
cutoffIdx = len(probsData) // Use all tokens if topP is very high
}
// Sample from the truncated distribution
// Renormalize the truncated probabilities
truncatedProbs := make([]float32, cutoffIdx)
for i := 0; i < cutoffIdx; i++ {
truncatedProbs[i] = probsData[i] / totalProb
}
// Sample using random number with provided seed for reproducibility
r := mlx.RandomUniform([]int32{1}, uint64(seed))
mlx.Eval(r)
randVal := r.Data()[0]
// Find the sampled token
var cumulative float32
for i := 0; i < cutoffIdx; i++ {
cumulative += truncatedProbs[i]
if randVal < cumulative {
return indicesData[i]
}
}
// Fallback to the last token in truncated set
return indicesData[cutoffIdx-1]
}
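// Example: probs = [0.5, 0.3, 0.15, 0.05] (already sorted), topP = 0.75.
// cumProbs = [0.5, 0.8, ...]; the loop stops at i=1 (0.8 >= 0.75), so
// cutoffIdx = 2 and totalProb = 0.8. The truncated, renormalized distribution
// is [0.625, 0.375], and one of the top two tokens is drawn.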
// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}
// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
// Pre-attention norm
normed := b.InputLayerNorm.Forward(x, eps)
// Self-attention with RoPE/MRoPE and cache
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
// Post-attention norm (GLM-4 style)
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
// Residual connection
x = mlx.Add(x, attnOut)
// Post-attention layer norm
normed = b.PostAttnLayerNorm.Forward(x, eps)
// MLP
mlpOut := b.MLP.Forward(normed)
// Post-MLP norm
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
// Residual connection
x = mlx.Add(x, mlpOut)
return x
}
// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}
// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
// Q, K, V projections
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
// Add biases
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// Apply partial RoPE or MRoPE
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
if len(attn.MRoPESection) == 3 && positionIDs != nil {
// Use MRoPE with explicit position IDs
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else if len(attn.MRoPESection) == 3 {
// Use MRoPE with sequential positions (same for all dims - text mode)
seqPositions := make([][]int32, 3)
for dim := 0; dim < 3; dim++ {
seqPositions[dim] = make([]int32, L)
for i := int32(0); i < L; i++ {
seqPositions[dim][i] = i + posOffset
}
}
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
} else {
// Fallback to standard RoPE
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Update cache and get full K, V for attention
if kvcache != nil {
k, v = kvcache.Update(k, v, int(L))
}
// Repeat KV for GQA
kExpanded := k
vExpanded := v
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
kExpanded = repeatKV(k, repeats)
vExpanded = repeatKV(v, repeats)
}
// Scaled dot-product attention with causal mask
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Reshape to [B, L, hidden_size]
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
// Output projection
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
return out
}
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}
// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply RoPE to rotary part with position offset
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
if rotaryDim <= 0 || rotaryDim > D {
rotaryDim = D
}
// Split into rotary and pass-through parts
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
// Apply MRoPE to rotary part
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
// Concatenate back
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}
// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Validate mrope_section sums to half (number of frequency pairs)
var totalPairs int32
for _, s := range mropeSection {
totalPairs += s
}
if totalPairs != half {
// Fallback to standard RoPE if section doesn't match
return applyRoPEWithOffset(x, L, 0, theta)
}
// Build angles for each position dimension (matching Python's MRoPE approach)
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
angleVals := make([]*mlx.Array, 3)
freqOffset := int32(0)
for dim := 0; dim < 3; dim++ {
numPairs := mropeSection[dim]
if numPairs == 0 {
continue
}
// Compute inverse frequencies for this section
// Each dimension uses DIFFERENT frequency ranges:
// - Temporal: frequencies 0 to section[0]-1
// - Height: frequencies section[0] to section[0]+section[1]-1
// - Width: frequencies section[0]+section[1] to sum(section)-1
freqsArr := make([]float32, numPairs)
for i := int32(0); i < numPairs; i++ {
globalIdx := freqOffset + i
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
// Position indices for this dimension
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(positionIDs[dim][i])
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, numPairs] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
freqOffset += numPairs
}
// Concatenate all sections: [L, half] = [L, 32]
allAngles := mlx.Concatenate(angleVals, 1)
// Duplicate AFTER concatenation: [L, D] = [L, 64]
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
// Compute cos/sin
allCos := mlx.Cos(allAngles)
allSin := mlx.Sin(allAngles)
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
allCos = mlx.Reshape(allCos, 1, L, 1, D)
allSin = mlx.Reshape(allSin, 1, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
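// Note: for text tokens all three position dimensions carry the same index p,
// so every frequency sees the angle p*invFreq and MRoPE reduces exactly to the
// standard RoPE computed by applyRoPEWithOffset.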
// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
return applyRoPEWithOffset(x, seqLen, 0, theta)
}
// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
// Compute inverse frequencies: 1 / (theta^(2i/d))
freqsArr := make([]float32, half)
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
// Position indices with offset
posArr := make([]float32, L)
for i := int32(0); i < L; i++ {
posArr[i] = float32(i + posOffset)
}
pos := mlx.NewArray(posArr, []int32{L})
// Compute angles: [L, half] = outer(pos, freqs)
posExpanded := mlx.Reshape(pos, L, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
angles := mlx.Mul(posExpanded, freqsExpanded)
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
cosVals := mlx.Cos(anglesDup)
sinVals := mlx.Sin(anglesDup)
cosVals = mlx.Reshape(cosVals, L, 1, D)
sinVals = mlx.Reshape(sinVals, L, 1, D)
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
// x: [B, nkvheads, L, head_dim]
x = mlx.ExpandDims(x, 2)
// x: [B, nkvheads, 1, L, head_dim]
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
// x: [B, nkvheads, repeats, L, head_dim]
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
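// Example: with NKVHeads=2 and NHeads=32, repeats=16, so a [B, 2, L, D] cache
// expands to [B, 32, L, D] ordered [kv0 ×16, kv1 ×16]; query head h reads
// kv head h/16.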
// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
// gate_up_proj outputs [gate, up] concatenated
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
shape := gateUp.Shape()
halfDim := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
gate = mlx.SiLU(gate)
h := mlx.Mul(gate, up)
// Down projection
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}


@@ -1,22 +0,0 @@
package imagegen
import (
"io"
"strings"
)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
// When quantize is true, returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
// ShouldQuantize returns true if a tensor should be quantized.
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
func ShouldQuantize(name, component string) bool {
if component == "vae" {
return false
}
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
return false
}
return strings.HasSuffix(name, ".weight")
}
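// Example calls (tensor names are hypothetical):
// ShouldQuantize("transformer.layers.0.attn.q_proj.weight", "transformer") == true
// ShouldQuantize("decoder.conv_in.weight", "vae") == false (VAE is skipped)
// ShouldQuantize("model.embed_tokens.weight", "text_encoder") == false (embedding)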


@@ -19,15 +19,9 @@ import (
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/glm_image"
"github.com/ollama/ollama/x/imagegen/models/zimage"
)
// ImageModel is the interface for image generation models
type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
}
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
@@ -47,9 +41,8 @@ type Response struct {
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
model ImageModel
model *zimage.Model
modelName string
modelType string // "zimage" or "glm_image"
}
// Execute is the entry point for the image runner subprocess
@@ -79,35 +72,15 @@ func Execute(args []string) error {
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
}
// Detect model type and load appropriate model
modelType, err := detectModelType(*modelName)
if err != nil {
return fmt.Errorf("failed to detect model type: %w", err)
}
var model ImageModel
switch modelType {
case "GlmImagePipeline":
slog.Info("loading GLM-Image model")
m := &glm_image.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load GLM-Image model: %w", err)
}
model = m
default:
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
slog.Info("loading Z-Image model")
m := &zimage.Model{}
if err := m.Load(*modelName); err != nil {
return fmt.Errorf("failed to load Z-Image model: %w", err)
}
model = m
// Load model
model := &zimage.Model{}
if err := model.Load(*modelName); err != nil {
return fmt.Errorf("failed to load model: %w", err)
}
server := &Server{
model: model,
modelName: *modelName,
modelType: modelType,
}
// Set up HTTP handlers
@@ -171,13 +144,7 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
req.Height = 1024
}
if req.Steps <= 0 {
// Default steps depend on model type
switch s.modelType {
case "GlmImagePipeline":
req.Steps = 50 // GLM-Image default
default:
req.Steps = 9 // Z-Image turbo default
}
req.Steps = 9
}
if req.Seed <= 0 {
req.Seed = time.Now().UnixNano()
@@ -192,9 +159,25 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Generate image using interface method
// Generate image
ctx := r.Context()
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: req.Seed,
Progress: func(step, total int) {
resp := Response{
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
Done: false,
}
data, _ := json.Marshal(resp)
w.Write(data)
w.Write([]byte("\n"))
flusher.Flush()
},
})
if err != nil {
// Don't send error for cancellation
@@ -233,35 +216,3 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("\n"))
flusher.Flush()
}
// detectModelType reads the model manifest and returns the pipeline class name
func detectModelType(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", err
}
data, err := manifest.ReadConfig("model_index.json")
if err != nil {
return "ZImagePipeline", nil // Default to Z-Image
}
// Try both _class_name (diffusers format) and architecture (ollama format)
var index struct {
ClassName string `json:"_class_name"`
Architecture string `json:"architecture"`
}
if err := json.Unmarshal(data, &index); err != nil {
return "ZImagePipeline", nil
}
// Prefer _class_name, fall back to architecture
className := index.ClassName
if className == "" {
className = index.Architecture
}
if className == "" {
return "ZImagePipeline", nil
}
return className, nil
}

x/server/show.go Normal file

@@ -0,0 +1,271 @@
package server
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/imagegen"
)
// modelConfig represents the HuggingFace config.json structure
type modelConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
VocabSize int `json:"vocab_size"`
RMSNormEps float64 `json:"rms_norm_eps"`
RopeTheta float64 `json:"rope_theta"`
TorchDtype string `json:"torch_dtype"`
TextConfig *struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
} `json:"text_config"`
}
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
totalBytes += layer.Size
tensorCount++
}
}
return buildModelInfo(config, totalBytes, tensorCount), nil
}
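
A usage sketch for the function above; the import path is inferred from the file's location under `x/server`, and the model name is a placeholder for a locally imported safetensors model:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/x/server"
)

func main() {
	// "gemma3-st" is a hypothetical local model name.
	info, err := server.GetSafetensorsLLMInfo("gemma3-st")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(info["general.architecture"])    // e.g. "gemma3"
	fmt.Println(info["gemma3.context_length"])   // from max_position_embeddings
	fmt.Println(info["general.parameter_count"]) // estimated from tensor bytes
}
```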
// buildModelInfo constructs the model info map from config and tensor stats.
// This is separated for testability.
func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any {
// Determine architecture
arch := config.ModelType
if arch == "" && len(config.Architectures) > 0 {
// Convert HuggingFace architecture name to Ollama format
// e.g., "Gemma3ForCausalLM" -> "gemma3"
hfArch := config.Architectures[0]
arch = strings.ToLower(hfArch)
arch = strings.TrimSuffix(arch, "forcausallm")
arch = strings.TrimSuffix(arch, "forconditionalgeneration")
}
// Use text_config values if they exist (for multimodal models)
hiddenSize := config.HiddenSize
maxPosEmbed := config.MaxPositionEmbeddings
numLayers := config.NumHiddenLayers
if config.TextConfig != nil {
if config.TextConfig.HiddenSize > 0 {
hiddenSize = config.TextConfig.HiddenSize
}
if config.TextConfig.MaxPositionEmbeddings > 0 {
maxPosEmbed = config.TextConfig.MaxPositionEmbeddings
}
if config.TextConfig.NumHiddenLayers > 0 {
numLayers = config.TextConfig.NumHiddenLayers
}
}
// Get dtype to determine bytes per parameter for count calculation
dtype := config.TorchDtype
// Determine bytes per parameter based on dtype
var bytesPerParam int64 = 2 // default to float16/bfloat16
switch strings.ToLower(dtype) {
case "float32":
bytesPerParam = 4
case "float16", "bfloat16":
bytesPerParam = 2
case "int8", "uint8":
bytesPerParam = 1
}
// Subtract safetensors header overhead (88 bytes per tensor file)
// Each tensor is stored as a minimal safetensors file
totalBytes := totalTensorBytes - tensorCount*88
paramCount := totalBytes / bytesPerParam
info := map[string]any{
"general.architecture": arch,
}
if maxPosEmbed > 0 {
info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed
}
if hiddenSize > 0 {
info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize
}
if numLayers > 0 {
info[fmt.Sprintf("%s.block_count", arch)] = numLayers
}
if config.NumAttentionHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads
}
if config.NumKeyValueHeads > 0 {
info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads
}
if config.IntermediateSize > 0 {
info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize
}
if config.VocabSize > 0 {
info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize
}
if paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info
}
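
The sizing math above reduces to paramCount = (totalTensorBytes − tensorCount×88) / bytesPerParam. A worked sketch using the same numbers as the test cases below:

```go
package main

import "fmt"

// estimateParams mirrors buildModelInfo's sizing math: strip the 88-byte
// safetensors header from each tensor blob, then divide by the dtype's
// bytes per parameter (2 for bf16/f16, 4 for f32, 1 for int8/uint8).
func estimateParams(totalTensorBytes, tensorCount, bytesPerParam int64) int64 {
	return (totalTensorBytes - tensorCount*88) / bytesPerParam
}

func main() {
	// 10 bfloat16 blobs totaling 2,000,880 bytes:
	// (2,000,880 - 10*88) / 2 = 1,000,000 parameters
	fmt.Println(estimateParams(2_000_880, 10, 2))
}
```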
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
return getTensorInfoFromManifest(manifest)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
var tensors []api.Tensor
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType != "application/vnd.ollama.image.tensor" {
continue
}
// Read the safetensors header from the blob
blobPath := manifest.BlobPath(layer.Digest)
info, err := readSafetensorsHeader(blobPath)
if err != nil {
// Skip tensors we can't read
continue
}
// Convert shape from int to uint64
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}
tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: info.Dtype,
Shape: shape,
})
}
return tensors, nil
}
// GetSafetensorsDtype returns the torch_dtype from config.json for a safetensors model.
func GetSafetensorsDtype(modelName string) (string, error) {
manifest, err := imagegen.LoadManifest(modelName)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}
return cfg.TorchDtype, nil
}
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
}
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return parseSafetensorsHeader(f)
}
// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}
// Sanity check - header shouldn't be too large
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header size too large: %d", headerSize)
}
// Read header JSON
headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}
// Parse as map of tensor name -> info
var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
// Find the first tensor entry (each blob should contain exactly one)
for name, raw := range header {
if name == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
return &info, nil
}
return nil, fmt.Errorf("no tensor found in header")
}
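
To make the byte layout concrete, here is a minimal sketch written as if it lived in the same package (so it can reach the unexported parser); the tensor name, shape, and offsets are arbitrary:

```go
package server

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// ExampleParseHeader sketches the smallest input parseSafetensorsHeader
// accepts: an 8-byte little-endian header length, the JSON header, then
// raw tensor data (which the parser never reads).
func ExampleParseHeader() {
	header := []byte(`{"weight":{"dtype":"BF16","shape":[4,4],"data_offsets":[0,32]}}`)

	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint64(len(header)))
	buf.Write(header)
	buf.Write(make([]byte, 32)) // 4x4 bf16 payload, ignored by the parser

	info, _ := parseSafetensorsHeader(&buf)
	fmt.Println(info.Dtype, info.Shape)
	// Output: BF16 [4 4]
}
```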

x/server/show_test.go (new file, 605 lines)

@@ -0,0 +1,605 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/imagegen"
)
func TestBuildModelInfo(t *testing.T) {
tests := []struct {
name string
config modelConfig
totalTensorBytes int64
tensorCount int64
wantArch string
wantContextLen int
wantEmbedLen int
wantBlockCount int
wantParamCount int64
}{
{
name: "gemma3 model with model_type",
config: modelConfig{
ModelType: "gemma3",
HiddenSize: 2560,
NumHiddenLayers: 34,
MaxPositionEmbeddings: 131072,
IntermediateSize: 10240,
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "llama model with architectures array",
config: modelConfig{
Architectures: []string{"LlamaForCausalLM"},
HiddenSize: 4096,
NumHiddenLayers: 32,
MaxPositionEmbeddings: 4096,
IntermediateSize: 11008,
NumAttentionHeads: 32,
NumKeyValueHeads: 32,
VocabSize: 32000,
TorchDtype: "float16",
},
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
wantEmbedLen: 4096,
wantBlockCount: 32,
wantParamCount: 7_000_000_000,
},
{
name: "multimodal model with text_config",
config: modelConfig{
Architectures: []string{"Gemma3ForConditionalGeneration"},
HiddenSize: 1152, // vision hidden size
TextConfig: &struct {
HiddenSize int `json:"hidden_size"`
MaxPositionEmbeddings int `json:"max_position_embeddings"`
NumHiddenLayers int `json:"num_hidden_layers"`
}{
HiddenSize: 2560,
MaxPositionEmbeddings: 131072,
NumHiddenLayers: 34,
},
NumAttentionHeads: 8,
NumKeyValueHeads: 4,
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
wantEmbedLen: 2560,
wantBlockCount: 34,
wantParamCount: 4_300_000_000,
},
{
name: "float32 model",
config: modelConfig{
ModelType: "test",
HiddenSize: 512,
NumHiddenLayers: 6,
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
wantEmbedLen: 512,
wantBlockCount: 6,
wantParamCount: 100_000_000,
},
{
name: "multiple tensors with header overhead",
config: modelConfig{
ModelType: "test",
HiddenSize: 256,
NumHiddenLayers: 4,
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
wantEmbedLen: 256,
wantBlockCount: 4,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount)
// Check architecture
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
// Check context length
contextKey := tt.wantArch + ".context_length"
if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen {
t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen)
}
// Check embedding length
embedKey := tt.wantArch + ".embedding_length"
if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen {
t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen)
}
// Check block count
blockKey := tt.wantArch + ".block_count"
if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount {
t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount)
}
// Check parameter count
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestBuildModelInfo_ArchitectureConversion(t *testing.T) {
tests := []struct {
name string
architectures []string
modelType string
wantArch string
}{
{
name: "LlamaForCausalLM",
architectures: []string{"LlamaForCausalLM"},
wantArch: "llama",
},
{
name: "Gemma3ForCausalLM",
architectures: []string{"Gemma3ForCausalLM"},
wantArch: "gemma3",
},
{
name: "Gemma3ForConditionalGeneration",
architectures: []string{"Gemma3ForConditionalGeneration"},
wantArch: "gemma3",
},
{
name: "Qwen2ForCausalLM",
architectures: []string{"Qwen2ForCausalLM"},
wantArch: "qwen2",
},
{
name: "model_type takes precedence",
architectures: []string{"LlamaForCausalLM"},
modelType: "custom",
wantArch: "custom",
},
{
name: "empty architectures with model_type",
architectures: nil,
modelType: "mymodel",
wantArch: "mymodel",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
Architectures: tt.architectures,
ModelType: tt.modelType,
}
info := buildModelInfo(config, 0, 0)
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
}
})
}
}
func TestBuildModelInfo_BytesPerParam(t *testing.T) {
tests := []struct {
name string
dtype string
totalBytes int64
tensorCount int64
wantParamCount int64
}{
{
name: "bfloat16",
dtype: "bfloat16",
totalBytes: 2_000_088, // 1M * 2 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
totalBytes: 4_000_088, // 1M * 4 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
totalBytes: 1_000_088, // 1M * 1 + 88
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
totalBytes: 2_000_088,
tensorCount: 1,
wantParamCount: 1_000_000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := modelConfig{
ModelType: "test",
TorchDtype: tt.dtype,
}
info := buildModelInfo(config, tt.totalBytes, tt.tensorCount)
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
}
})
}
}
func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantDtype string
wantShape []int64
wantErr bool
}{
{
name: "simple tensor",
header: map[string]any{
"weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 262144},
"data_offsets": []int64{0, 1342177280},
},
},
wantDtype: "BF16",
wantShape: []int64{2560, 262144},
},
{
name: "with metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
},
"bias": map[string]any{
"dtype": "F32",
"shape": []int64{1024},
"data_offsets": []int64{0, 4096},
},
},
wantDtype: "F32",
wantShape: []int64{1024},
},
{
name: "float16 tensor",
header: map[string]any{
"layer.weight": map[string]any{
"dtype": "F16",
"shape": []int64{512, 512, 3, 3},
"data_offsets": []int64{0, 4718592},
},
},
wantDtype: "F16",
wantShape: []int64{512, 512, 3, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create safetensors format: 8-byte size + JSON header
headerJSON, err := json.Marshal(tt.header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}
var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
buf.Write(headerJSON)
info, err := parseSafetensorsHeader(&buf)
if (err != nil) != tt.wantErr {
t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if info.Dtype != tt.wantDtype {
t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
}
if len(info.Shape) != len(tt.wantShape) {
t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
} else {
for i, s := range info.Shape {
if s != tt.wantShape[i] {
t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
}
}
}
})
}
}
func TestParseSafetensorsHeader_Errors(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr string
}{
{
name: "empty data",
data: []byte{},
wantErr: "failed to read header size",
},
{
name: "truncated header size",
data: []byte{0x01, 0x02, 0x03},
wantErr: "failed to read header size",
},
{
name: "header size too large",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
return buf.Bytes()
}(),
wantErr: "header size too large",
},
{
name: "truncated header",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(100))
buf.Write([]byte("short"))
return buf.Bytes()
}(),
wantErr: "failed to read header",
},
{
name: "invalid JSON",
data: func() []byte {
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(10))
buf.Write([]byte("not json!!"))
return buf.Bytes()
}(),
wantErr: "failed to parse header",
},
{
name: "no tensors in header",
data: func() []byte {
header := map[string]any{
"__metadata__": map[string]any{"format": "pt"},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
return buf.Bytes()
}(),
wantErr: "no tensor found in header",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
if err == nil {
t.Error("expected error, got nil")
return
}
if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) {
t.Errorf("error = %v, want error containing %v", err, tt.wantErr)
}
})
}
}
func TestGetTensorInfoFromManifest(t *testing.T) {
// Create a temp directory for blobs
tempDir, err := os.MkdirTemp("", "ollama-test-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// Create test tensor blobs
tensors := []struct {
name string
digest string
dtype string
shape []int64
}{
{
name: "model.embed_tokens.weight",
digest: "sha256:abc123",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
digest: "sha256:def456",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
digest: "sha256:ghi789",
dtype: "F32",
shape: []int64{2560},
},
}
// Create blob files
var layers []imagegen.ManifestLayer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
"shape": tensor.shape,
"data_offsets": []int64{0, 1000},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
// Write blob file
blobName := "sha256-" + tensor.digest[7:]
blobPath := filepath.Join(tempDir, blobName)
if err := os.WriteFile(blobPath, buf.Bytes(), 0644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.tensor",
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
})
}
// Add a non-tensor layer (should be skipped)
layers = append(layers, imagegen.ManifestLayer{
MediaType: "application/vnd.ollama.image.json",
Digest: "sha256:config",
Size: 100,
Name: "config.json",
})
manifest := &imagegen.ModelManifest{
Manifest: &imagegen.Manifest{
Layers: layers,
},
BlobDir: tempDir,
}
result, err := getTensorInfoFromManifest(manifest)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
if len(result) != 3 {
t.Errorf("got %d tensors, want 3", len(result))
}
// Verify each tensor
for i, tensor := range tensors {
if i >= len(result) {
break
}
if result[i].Name != tensor.name {
t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name)
}
if result[i].Type != tensor.dtype {
t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype)
}
if len(result[i].Shape) != len(tensor.shape) {
t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape))
}
}
}
func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir, err := os.MkdirTemp("", "ollama-test-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
header := map[string]any{
"test_tensor": map[string]any{
"dtype": "BF16",
"shape": []int64{1024, 768},
"data_offsets": []int64{0, 1572864},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
filePath := filepath.Join(tempDir, "test.safetensors")
if err := os.WriteFile(filePath, buf.Bytes(), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
info, err := readSafetensorsHeader(filePath)
if err != nil {
t.Fatalf("readSafetensorsHeader() error = %v", err)
}
if info.Dtype != "BF16" {
t.Errorf("Dtype = %v, want BF16", info.Dtype)
}
if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
}
}
func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
_, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
if err == nil {
t.Error("expected error for nonexistent file")
}
}