mirror of
https://github.com/ollama/ollama.git
synced 2026-01-15 19:09:25 -05:00
Compare commits
5 Commits
imagegen-r
...
pdevine/x-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b95add4e3 | ||
|
|
4adb9cf4bb | ||
|
|
74f475e735 | ||
|
|
875cecba74 | ||
|
|
7d411a4686 |
42
README.md
42
README.md
@@ -48,7 +48,7 @@ ollama run gemma3
|
||||
|
||||
## Model library
|
||||
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library "ollama model library")
|
||||
Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
|
||||
|
||||
Here are some example models that can be downloaded:
|
||||
|
||||
@@ -79,7 +79,7 @@ Here are some example models that can be downloaded:
|
||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
|
||||
|
||||
> [!NOTE]
|
||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||
@@ -260,38 +260,6 @@ Finally, in a separate shell, run a model:
|
||||
./ollama run llama3.2
|
||||
```
|
||||
|
||||
## Building with MLX (experimental)
|
||||
|
||||
First build the MLX libraries:
|
||||
|
||||
```shell
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`):
|
||||
|
||||
```shell
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
Finally, start the server:
|
||||
|
||||
```
|
||||
./ollama serve
|
||||
```
|
||||
|
||||
### Building MLX with CUDA
|
||||
|
||||
When building with CUDA, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with default architectures:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
```
|
||||
|
||||
## REST API
|
||||
|
||||
Ollama has a REST API for running and managing models.
|
||||
@@ -453,7 +421,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
@@ -525,7 +493,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Database
|
||||
|
||||
- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md)
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases)
|
||||
@@ -668,7 +636,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
|
||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
@@ -677,5 +644,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||
|
||||
### Security
|
||||
|
||||
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
|
||||
|
||||
91
cmd/cmd.go
91
cmd/cmd.go
@@ -46,8 +46,9 @@ import (
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
xcreateclient "github.com/ollama/ollama/x/create/client"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenclient "github.com/ollama/ollama/x/imagegen/client"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
@@ -93,15 +94,82 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
// Check for --experimental flag for safetensors model creation
|
||||
experimental, _ := cmd.Flags().GetBool("experimental")
|
||||
if experimental {
|
||||
modelName := args[0]
|
||||
|
||||
// Get Modelfile content - either from -f flag or default to "FROM ."
|
||||
var reader io.Reader
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) || filename == "" {
|
||||
// No Modelfile specified or found - use default
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
}
|
||||
|
||||
// Parse the Modelfile
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Modelfile: %w", err)
|
||||
}
|
||||
|
||||
// Extract FROM path and configuration
|
||||
var modelDir string
|
||||
mfConfig := &xcreateclient.ModelfileConfig{}
|
||||
|
||||
for _, cmd := range modelfile.Commands {
|
||||
switch cmd.Name {
|
||||
case "model":
|
||||
modelDir = cmd.Args
|
||||
case "template":
|
||||
mfConfig.Template = cmd.Args
|
||||
case "system":
|
||||
mfConfig.System = cmd.Args
|
||||
case "license":
|
||||
mfConfig.License = cmd.Args
|
||||
}
|
||||
}
|
||||
|
||||
if modelDir == "" {
|
||||
modelDir = "."
|
||||
}
|
||||
|
||||
// Resolve relative paths based on Modelfile location
|
||||
if !filepath.IsAbs(modelDir) && filename != "" {
|
||||
modelDir = filepath.Join(filepath.Dir(filename), modelDir)
|
||||
}
|
||||
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: modelName,
|
||||
ModelDir: modelDir,
|
||||
Quantize: quantize,
|
||||
Modelfile: mfConfig,
|
||||
}, p)
|
||||
}
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
// No Modelfile found - check if current directory is an image gen model
|
||||
if imagegen.IsTensorModelDir(".") {
|
||||
if create.IsTensorModelDir(".") {
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
return imagegenclient.CreateModel(args[0], ".", quantize, p)
|
||||
return xcreateclient.CreateModel(xcreateclient.CreateOptions{
|
||||
ModelName: args[0],
|
||||
ModelDir: ".",
|
||||
Quantize: quantize,
|
||||
}, p)
|
||||
}
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
@@ -1742,15 +1810,22 @@ func NewCLI() *cobra.Command {
|
||||
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Skip server check for experimental mode (writes directly to disk)
|
||||
if experimental, _ := cmd.Flags().GetBool("experimental"); experimental {
|
||||
return nil
|
||||
}
|
||||
return checkServerHeartbeat(cmd, args)
|
||||
},
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
createCmd.Flags().Bool("experimental", false, "Enable experimental safetensors model creation")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
Use: "show MODEL",
|
||||
|
||||
@@ -110,7 +110,7 @@ More Ollama [Python example](https://github.com/ollama/ollama-python/blob/main/e
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const results = await client.webSearch({ query: "what is ollama?" });
|
||||
const results = await client.webSearch("what is ollama?");
|
||||
console.log(JSON.stringify(results, null, 2));
|
||||
```
|
||||
|
||||
@@ -213,7 +213,7 @@ models](https://ollama.com/models)\n\nAvailable for macOS, Windows, and Linux',
|
||||
import { Ollama } from "ollama";
|
||||
|
||||
const client = new Ollama();
|
||||
const fetchResult = await client.webFetch({ url: "https://ollama.com" });
|
||||
const fetchResult = await client.webFetch("https://ollama.com");
|
||||
console.log(JSON.stringify(fetchResult, null, 2));
|
||||
```
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Linux"
|
||||
title: Linux
|
||||
---
|
||||
|
||||
## Install
|
||||
@@ -13,14 +13,15 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
## Manual install
|
||||
|
||||
<Note>
|
||||
If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||
If you are upgrading from a prior version, you should remove the old libraries
|
||||
with `sudo rm -rf /usr/lib/ollama` first.
|
||||
</Note>
|
||||
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
Start Ollama:
|
||||
@@ -40,8 +41,8 @@ ollama -v
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### ARM64 install
|
||||
@@ -49,8 +50,8 @@ curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz \
|
||||
Download and extract the ARM64-specific package:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-arm64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
### Adding Ollama as a startup service (recommended)
|
||||
@@ -112,7 +113,11 @@ sudo systemctl status ollama
|
||||
```
|
||||
|
||||
<Note>
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux kernel source, the version is older and may not support all ROCm features. We recommend you install the latest driver from https://www.amd.com/en/support/linux-drivers for best support of your Radeon GPU.
|
||||
While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
kernel source, the version is older and may not support all ROCm features. We
|
||||
recommend you install the latest driver from
|
||||
https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
GPU.
|
||||
</Note>
|
||||
|
||||
## Customizing
|
||||
@@ -141,8 +146,8 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Or by re-downloading Ollama:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz \
|
||||
| sudo tar zx -C /usr
|
||||
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tar.zst \
|
||||
| sudo tar x -C /usr
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
@@ -191,4 +196,4 @@ Remove the downloaded models and Ollama service user and group:
|
||||
sudo userdel ollama
|
||||
sudo groupdel ollama
|
||||
sudo rm -r /usr/share/ollama
|
||||
```
|
||||
```
|
||||
|
||||
@@ -179,7 +179,7 @@ _build_macapp() {
|
||||
fi
|
||||
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
(cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz
|
||||
|
||||
# Notarize and Staple
|
||||
@@ -187,7 +187,7 @@ _build_macapp() {
|
||||
$(xcrun -f notarytool) submit dist/Ollama-darwin.zip --wait --timeout 20m --apple-id "$APPLE_ID" --password "$APPLE_PASSWORD" --team-id "$APPLE_TEAM_ID"
|
||||
rm -f dist/Ollama-darwin.zip
|
||||
$(xcrun -f stapler) staple dist/Ollama.app
|
||||
ditto -c -k --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip
|
||||
|
||||
rm -f dist/Ollama.dmg
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
||||
xserver "github.com/ollama/ollama/x/server"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
@@ -1133,6 +1134,22 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate details from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
|
||||
modelDetails.Family = arch
|
||||
}
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
|
||||
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
|
||||
}
|
||||
}
|
||||
// Get torch_dtype directly from config.json for quantization level
|
||||
if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
|
||||
modelDetails.QuantizationLevel = dtype
|
||||
}
|
||||
}
|
||||
|
||||
if req.System != "" {
|
||||
m.System = req.System
|
||||
}
|
||||
@@ -1219,6 +1236,20 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// For safetensors LLM models (experimental), populate ModelInfo from config.json
|
||||
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
|
||||
if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
|
||||
resp.ModelInfo = info
|
||||
}
|
||||
// Populate tensor info if verbose
|
||||
if req.Verbose {
|
||||
if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
|
||||
resp.Tensors = tensors
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
50
x/README.md
Normal file
50
x/README.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Experimental Features
|
||||
|
||||
## MLX Backend
|
||||
|
||||
We're working on a new experimental backend based on the [MLX project](https://github.com/ml-explore/mlx)
|
||||
|
||||
Support is currently limited to MacOS and Linux with CUDA GPUs. We're looking to add support for Windows CUDA soon, and other GPU vendors.
|
||||
|
||||
### Building ollama-mlx
|
||||
|
||||
The `ollama-mlx` binary is a separate build of Ollama with MLX support enabled. This enables experimental features like image generation.
|
||||
|
||||
#### macOS (Apple Silicon and Intel)
|
||||
|
||||
```bash
|
||||
# Build MLX backend libraries
|
||||
cmake --preset MLX
|
||||
cmake --build --preset MLX --parallel
|
||||
cmake --install build --component MLX
|
||||
|
||||
# Build ollama-mlx binary
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
#### Linux (CUDA)
|
||||
|
||||
On Linux, use the preset "MLX CUDA 13" or "MLX CUDA 12" to enable CUDA with the default Ollama NVIDIA GPU architectures enabled:
|
||||
|
||||
```bash
|
||||
# Build MLX backend libraries with CUDA support
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
cmake --build --preset 'MLX CUDA 13' --parallel
|
||||
cmake --install build --component MLX
|
||||
|
||||
# Build ollama-mlx binary
|
||||
CGO_CFLAGS="-O3 -I$(pwd)/build/_deps/mlx-c-src" \
|
||||
CGO_LDFLAGS="-L$(pwd)/build/lib/ollama -lmlxc -lmlx" \
|
||||
go build -tags mlx -o ollama-mlx .
|
||||
```
|
||||
|
||||
#### Using build scripts
|
||||
|
||||
The build scripts automatically create the `ollama-mlx` binary:
|
||||
|
||||
- **macOS**: `./scripts/build_darwin.sh` produces `dist/darwin/ollama-mlx`
|
||||
- **Linux**: `./scripts/build_linux.sh` produces `ollama-mlx` in the output archives
|
||||
|
||||
## Image Generation
|
||||
|
||||
Image generation is built into the `ollama-mlx` binary. Run `ollama-mlx serve` to start the server with image generation support enabled.
|
||||
282
x/create/client/create.go
Normal file
282
x/create/client/create.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package client provides client-side model creation for safetensors-based models.
|
||||
//
|
||||
// This package is in x/ because the safetensors model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/create, so x/create
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/create"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// ModelfileConfig holds configuration extracted from a Modelfile.
|
||||
type ModelfileConfig struct {
|
||||
Template string
|
||||
System string
|
||||
License string
|
||||
}
|
||||
|
||||
// CreateOptions holds all options for model creation.
|
||||
type CreateOptions struct {
|
||||
ModelName string
|
||||
ModelDir string
|
||||
Quantize string // "fp8" for quantization
|
||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
||||
}
|
||||
|
||||
// CreateModel imports a model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// Automatically detects model type (safetensors LLM vs image gen) and routes accordingly.
|
||||
func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
// Detect model type
|
||||
isSafetensors := create.IsSafetensorsModelDir(opts.ModelDir)
|
||||
isImageGen := create.IsTensorModelDir(opts.ModelDir)
|
||||
|
||||
if !isSafetensors && !isImageGen {
|
||||
return fmt.Errorf("%s is not a supported model directory (needs config.json + *.safetensors or model_index.json)", opts.ModelDir)
|
||||
}
|
||||
|
||||
// Determine model type settings
|
||||
var modelType, spinnerKey string
|
||||
var capabilities []string
|
||||
if isSafetensors {
|
||||
modelType = "safetensors model"
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
capabilities = []string{"image"}
|
||||
}
|
||||
|
||||
// Set up progress spinner
|
||||
statusMsg := "importing " + modelType
|
||||
spinner := progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
statusMsg = msg
|
||||
spinner = progress.NewSpinner(statusMsg)
|
||||
p.Add(spinnerKey, spinner)
|
||||
}
|
||||
|
||||
// Create the model using shared callbacks
|
||||
var err error
|
||||
if isSafetensors {
|
||||
err = create.CreateSafetensorsModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
} else {
|
||||
err = create.CreateImageGenModel(
|
||||
opts.ModelName, opts.ModelDir, opts.Quantize,
|
||||
newLayerCreator(), newTensorLayerCreator(),
|
||||
newManifestWriter(opts, capabilities),
|
||||
progressFn,
|
||||
)
|
||||
}
|
||||
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s '%s'\n", modelType, opts.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
|
||||
func newLayerCreator() create.LayerCreator {
|
||||
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return create.LayerInfo{}, err
|
||||
}
|
||||
|
||||
return create.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newTensorLayerCreator returns a QuantizingTensorLayerCreator callback for creating tensor layers.
|
||||
// When doQuantize is true, returns multiple layers (weight + scales + optional qbias).
|
||||
func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
|
||||
return func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]create.LayerInfo, error) {
|
||||
if doQuantize {
|
||||
return createQuantizedLayers(r, name, dtype, shape)
|
||||
}
|
||||
return createUnquantizedLayer(r, name)
|
||||
}
|
||||
}
|
||||
|
||||
// createQuantizedLayers quantizes a tensor and returns the resulting layers.
|
||||
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32) ([]create.LayerInfo, error) {
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor (affine mode returns weight, scales, qbiases)
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []create.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale",
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, create.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias",
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// createUnquantizedLayer creates a single tensor layer without quantization.
|
||||
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []create.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
|
||||
func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
|
||||
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: capabilities,
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Add Modelfile layers if present
|
||||
if opts.Modelfile != nil {
|
||||
modelfileLayers, err := createModelfileLayers(opts.Modelfile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverLayers = append(serverLayers, modelfileLayers...)
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
}
|
||||
|
||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||
func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
|
||||
var layers []server.Layer
|
||||
|
||||
if mf.Template != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create template layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.System != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create system layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if mf.License != "" {
|
||||
layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create license layer: %w", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
146
x/create/client/create_test.go
Normal file
146
x/create/client/create_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelfileConfig(t *testing.T) {
|
||||
// Test that ModelfileConfig struct works as expected
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
System: "You are a helpful assistant.",
|
||||
License: "MIT",
|
||||
}
|
||||
|
||||
if config.Template != "{{ .Prompt }}" {
|
||||
t.Errorf("Template = %q, want %q", config.Template, "{{ .Prompt }}")
|
||||
}
|
||||
if config.System != "You are a helpful assistant." {
|
||||
t.Errorf("System = %q, want %q", config.System, "You are a helpful assistant.")
|
||||
}
|
||||
if config.License != "MIT" {
|
||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_Empty(t *testing.T) {
|
||||
config := &ModelfileConfig{}
|
||||
|
||||
if config.Template != "" {
|
||||
t.Errorf("Template should be empty, got %q", config.Template)
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Errorf("System should be empty, got %q", config.System)
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Errorf("License should be empty, got %q", config.License)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||
// Test config with only some fields set
|
||||
config := &ModelfileConfig{
|
||||
Template: "{{ .Prompt }}",
|
||||
// System and License intentionally empty
|
||||
}
|
||||
|
||||
if config.Template == "" {
|
||||
t.Error("Template should not be empty")
|
||||
}
|
||||
if config.System != "" {
|
||||
t.Error("System should be empty")
|
||||
}
|
||||
if config.License != "" {
|
||||
t.Error("License should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinOllamaVersion(t *testing.T) {
|
||||
// Verify the minimum version constant is set
|
||||
if MinOllamaVersion == "" {
|
||||
t.Error("MinOllamaVersion should not be empty")
|
||||
}
|
||||
if MinOllamaVersion != "0.14.0" {
|
||||
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_InvalidDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for invalid directory
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: "/nonexistent/path",
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateModel_NotSafetensorsDir(t *testing.T) {
|
||||
// Test that CreateModel returns error for directory without safetensors
|
||||
dir := t.TempDir()
|
||||
|
||||
err := CreateModel(CreateOptions{
|
||||
ModelName: "test-model",
|
||||
ModelDir: dir,
|
||||
}, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "my-model",
|
||||
ModelDir: "/path/to/model",
|
||||
Quantize: "fp8",
|
||||
Modelfile: &ModelfileConfig{
|
||||
Template: "test",
|
||||
System: "system",
|
||||
License: "MIT",
|
||||
},
|
||||
}
|
||||
|
||||
if opts.ModelName != "my-model" {
|
||||
t.Errorf("ModelName = %q, want %q", opts.ModelName, "my-model")
|
||||
}
|
||||
if opts.ModelDir != "/path/to/model" {
|
||||
t.Errorf("ModelDir = %q, want %q", opts.ModelDir, "/path/to/model")
|
||||
}
|
||||
if opts.Quantize != "fp8" {
|
||||
t.Errorf("Quantize = %q, want %q", opts.Quantize, "fp8")
|
||||
}
|
||||
if opts.Modelfile == nil {
|
||||
t.Error("Modelfile should not be nil")
|
||||
}
|
||||
if opts.Modelfile.Template != "test" {
|
||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOptions_Defaults(t *testing.T) {
|
||||
opts := CreateOptions{
|
||||
ModelName: "test",
|
||||
ModelDir: "/tmp",
|
||||
}
|
||||
|
||||
// Quantize should default to empty
|
||||
if opts.Quantize != "" {
|
||||
t.Errorf("Quantize should be empty by default, got %q", opts.Quantize)
|
||||
}
|
||||
|
||||
// Modelfile should default to nil
|
||||
if opts.Modelfile != nil {
|
||||
t.Error("Modelfile should be nil by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
supported := QuantizeSupported()
|
||||
|
||||
// In non-mlx builds, this should be false
|
||||
// We can't easily test both cases, so just verify it returns something
|
||||
_ = supported
|
||||
}
|
||||
391
x/create/create.go
Normal file
391
x/create/create.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// ModelConfig represents the config blob stored with a model.
|
||||
type ModelConfig struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
// Manifest represents the manifest JSON structure.
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config ManifestLayer `json:"config"`
|
||||
Layers []ManifestLayer `json:"layers"`
|
||||
}
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
type ManifestLayer struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// defaultManifestDir returns the manifest storage directory.
|
||||
func defaultManifestDir() string {
|
||||
return filepath.Join(envconfig.Models(), "manifests")
|
||||
}
|
||||
|
||||
// defaultBlobDir returns the blob storage directory.
|
||||
func defaultBlobDir() string {
|
||||
return filepath.Join(envconfig.Models(), "blobs")
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(defaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// loadManifest loads a manifest for the given model name.
|
||||
func loadManifest(modelName string) (*Manifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// loadModelConfig loads the config blob for a model.
|
||||
func loadModelConfig(modelName string) (*ModelConfig, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read the config blob
|
||||
blobName := strings.Replace(manifest.Config.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config ModelConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModel checks if a model was created with the experimental
|
||||
// safetensors builder by checking the model format in the config.
|
||||
func IsSafetensorsModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// IsSafetensorsLLMModel checks if a model is a safetensors LLM model
|
||||
// (has completion capability, not image generation).
|
||||
func IsSafetensorsLLMModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "completion")
|
||||
}
|
||||
|
||||
// IsImageGenModel checks if a model is an image generation model
|
||||
// (has image capability).
|
||||
func IsImageGenModel(modelName string) bool {
|
||||
config, err := loadModelConfig(modelName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return config.ModelFormat == "safetensors" && slices.Contains(config.Capabilities, "image")
|
||||
}
|
||||
|
||||
// GetModelArchitecture returns the architecture from the model's config.json layer.
|
||||
func GetModelArchitecture(modelName string) (string, error) {
|
||||
manifest, err := loadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Find the config.json layer
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.Name == "config.json" && layer.MediaType == "application/vnd.ollama.image.json" {
|
||||
blobName := strings.Replace(layer.Digest, ":", "-", 1)
|
||||
blobPath := filepath.Join(defaultBlobDir(), blobName)
|
||||
|
||||
data, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Prefer model_type, fall back to first architecture
|
||||
if cfg.ModelType != "" {
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
if len(cfg.Architectures) > 0 {
|
||||
return cfg.Architectures[0], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("architecture not found in model config")
|
||||
}
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a diffusers-style tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSafetensorsModelDir checks if the directory contains a standard safetensors model
|
||||
// by looking for config.json and at least one .safetensors file.
|
||||
func IsSafetensorsModelDir(dir string) bool {
|
||||
// Must have config.json
|
||||
if _, err := os.Stat(filepath.Join(dir, "config.json")); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must have at least one .safetensors file
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is true, returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// For image gen models (component non-empty): quantizes linear weights, skipping VAE, embeddings, norms.
|
||||
// For LLM models (component empty): quantizes linear weights, skipping embeddings, norms, and small tensors.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
// Image gen specific: skip VAE entirely
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip layer norms and RMS norms
|
||||
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip biases
|
||||
if strings.HasSuffix(name, ".bias") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize weights
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
|
||||
// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
|
||||
// This is a more detailed check that also considers tensor dimensions.
|
||||
func ShouldQuantizeTensor(name string, shape []int32) bool {
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateSafetensorsModel imports a standard safetensors model from a directory.
|
||||
// This handles Hugging Face style models with config.json and *.safetensors files.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
|
||||
func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
|
||||
entries, err := os.ReadDir(modelDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
// Process all safetensors files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(modelDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" {
|
||||
quantizeMsg = fmt.Sprintf(", quantizing to %s", quantize)
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Determine if this tensor should be quantized
|
||||
doQuantize := quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape)
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, doQuantize)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
|
||||
// Process all JSON config files
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip the index file as we don't need it after extraction
|
||||
if entry.Name() == "model.safetensors.index.json" {
|
||||
continue
|
||||
}
|
||||
|
||||
cfgPath := entry.Name()
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
layer, err := createLayer(f, "application/vnd.ollama.image.json", cfgPath)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use config.json as the config layer
|
||||
if cfgPath == "config.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("config.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
752
x/create/create_test.go
Normal file
752
x/create/create_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTensorModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid diffusers model with model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(`{"_class_name": "FluxPipeline"}`), 0644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "directory with other files but no model_index.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsTensorModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsTensorModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(dir string) error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid safetensors model with config.json and .safetensors file",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{"model_type": "gemma3"}`), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "config.json only, no safetensors files",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "safetensors file only, no config.json",
|
||||
setup: func(dir string) error {
|
||||
return os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("dummy"), 0644)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setup: func(dir string) error {
|
||||
return nil
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple safetensors files with config.json",
|
||||
setup: func(dir string) error {
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "model-00001-of-00002.safetensors"), []byte("dummy"), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dir, "model-00002-of-00002.safetensors"), []byte("dummy"), 0644)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := tt.setup(dir); err != nil {
|
||||
t.Fatalf("setup failed: %v", err)
|
||||
}
|
||||
|
||||
got := IsSafetensorsModelDir(dir)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafetensorsModelDir_NonexistentDir(t *testing.T) {
|
||||
got := IsSafetensorsModelDir("/nonexistent/path/that/does/not/exist")
|
||||
if got != false {
|
||||
t.Errorf("IsSafetensorsModelDir() = %v for nonexistent dir, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalSafetensors creates a minimal valid safetensors file with one tensor
|
||||
func createMinimalSafetensors(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
|
||||
// Create a minimal safetensors file with a single float32 tensor
|
||||
header := map[string]interface{}{
|
||||
"test_tensor": map[string]interface{}{
|
||||
"dtype": "F32",
|
||||
"shape": []int{2, 2},
|
||||
"data_offsets": []int{0, 16}, // 4 float32 values = 16 bytes
|
||||
},
|
||||
}
|
||||
headerJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
// Pad header to 8-byte alignment
|
||||
padding := (8 - len(headerJSON)%8) % 8
|
||||
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
|
||||
|
||||
// Write file
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Write header size (8 bytes, little endian)
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
|
||||
// Write header
|
||||
if _, err := f.Write(headerJSON); err != nil {
|
||||
t.Fatalf("failed to write header: %v", err)
|
||||
}
|
||||
|
||||
// Write tensor data (16 bytes of zeros for 4 float32 values)
|
||||
if _, err := f.Write(make([]byte, 16)); err != nil {
|
||||
t.Fatalf("failed to write tensor data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Track what was created
|
||||
var createdLayers []LayerInfo
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var manifestConfigLayer LayerInfo
|
||||
var manifestLayers []LayerInfo
|
||||
var statusMessages []string
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return LayerInfo{}, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:test",
|
||||
Size: int64(len(data)),
|
||||
MediaType: mediaType,
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return layer, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:tensor_" + name,
|
||||
Size: int64(len(data)),
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: name,
|
||||
}
|
||||
createdLayers = append(createdLayers, layer)
|
||||
return []LayerInfo{layer}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
manifestConfigLayer = config
|
||||
manifestLayers = layers
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
// Run CreateSafetensorsModel
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify manifest was written
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-model" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-model")
|
||||
}
|
||||
|
||||
// Verify config layer was set
|
||||
if manifestConfigLayer.Name != "config.json" {
|
||||
t.Errorf("config layer name = %q, want %q", manifestConfigLayer.Name, "config.json")
|
||||
}
|
||||
|
||||
// Verify we have at least one tensor and one config layer
|
||||
hasTensor := false
|
||||
hasConfig := false
|
||||
for _, layer := range manifestLayers {
|
||||
if layer.Name == "test_tensor" {
|
||||
hasTensor = true
|
||||
}
|
||||
if layer.Name == "config.json" {
|
||||
hasConfig = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasTensor {
|
||||
t.Error("no tensor layer found in manifest")
|
||||
}
|
||||
if !hasConfig {
|
||||
t.Error("no config layer found in manifest")
|
||||
}
|
||||
|
||||
// Verify status messages were sent
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_NoConfigJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only a safetensors file, no config.json
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
// Mock callbacks (minimal)
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing config.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Mock callbacks
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
return LayerInfo{}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
return []LayerInfo{{}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_SkipsIndexJson(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create model.safetensors.index.json (should be skipped)
|
||||
indexJSON := `{"metadata": {"total_size": 100}, "weight_map": {}}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model.safetensors.index.json"), []byte(indexJSON), 0644); err != nil {
|
||||
t.Fatalf("failed to write index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var configNames []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
configNames = append(configNames, name)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateSafetensorsModel("test-model", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify model.safetensors.index.json was not included
|
||||
for _, name := range configNames {
|
||||
if name == "model.safetensors.index.json" {
|
||||
t.Error("model.safetensors.index.json should have been skipped")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManifestPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
wantParts []string // Parts that should appear in the path
|
||||
}{
|
||||
{
|
||||
name: "simple model name",
|
||||
modelName: "llama2",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with tag",
|
||||
modelName: "llama2:7b",
|
||||
wantParts: []string{"registry.ollama.ai", "library", "llama2", "7b"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace",
|
||||
modelName: "myuser/mymodel",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "latest"},
|
||||
},
|
||||
{
|
||||
name: "model name with namespace and tag",
|
||||
modelName: "myuser/mymodel:v1",
|
||||
wantParts: []string{"registry.ollama.ai", "myuser", "mymodel", "v1"},
|
||||
},
|
||||
{
|
||||
name: "fully qualified model name",
|
||||
modelName: "registry.example.com/namespace/model:tag",
|
||||
wantParts: []string{"registry.example.com", "namespace", "model", "tag"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := resolveManifestPath(tt.modelName)
|
||||
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(got, part) {
|
||||
t.Errorf("resolveManifestPath(%q) = %q, missing part %q", tt.modelName, got, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerInfo(t *testing.T) {
|
||||
layer := LayerInfo{
|
||||
Digest: "sha256:abc123",
|
||||
Size: 1024,
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Name: "model.weight",
|
||||
}
|
||||
|
||||
if layer.Digest != "sha256:abc123" {
|
||||
t.Errorf("Digest = %q, want %q", layer.Digest, "sha256:abc123")
|
||||
}
|
||||
if layer.Size != 1024 {
|
||||
t.Errorf("Size = %d, want %d", layer.Size, 1024)
|
||||
}
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
t.Errorf("MediaType = %q, want %q", layer.MediaType, "application/vnd.ollama.image.tensor")
|
||||
}
|
||||
if layer.Name != "model.weight" {
|
||||
t.Errorf("Name = %q, want %q", layer.Name, "model.weight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig(t *testing.T) {
|
||||
config := ModelConfig{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"completion", "chat"},
|
||||
}
|
||||
|
||||
if config.ModelFormat != "safetensors" {
|
||||
t.Errorf("ModelFormat = %q, want %q", config.ModelFormat, "safetensors")
|
||||
}
|
||||
if len(config.Capabilities) != 2 {
|
||||
t.Errorf("Capabilities length = %d, want %d", len(config.Capabilities), 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifest(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
Config: ManifestLayer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
},
|
||||
Layers: []ManifestLayer{
|
||||
{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: "sha256:layer1",
|
||||
Size: 1000,
|
||||
Name: "weight.bin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if manifest.SchemaVersion != 2 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", manifest.SchemaVersion, 2)
|
||||
}
|
||||
if manifest.Config.Digest != "sha256:config" {
|
||||
t.Errorf("Config.Digest = %q, want %q", manifest.Config.Digest, "sha256:config")
|
||||
}
|
||||
if len(manifest.Layers) != 1 {
|
||||
t.Errorf("Layers length = %d, want %d", len(manifest.Layers), 1)
|
||||
}
|
||||
if manifest.Layers[0].Name != "weight.bin" {
|
||||
t.Errorf("Layers[0].Name = %q, want %q", manifest.Layers[0].Name, "weight.bin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
component string
|
||||
want bool
|
||||
}{
|
||||
// VAE component should never be quantized
|
||||
{"vae weight", "decoder.weight", "vae", false},
|
||||
{"vae bias", "decoder.bias", "vae", false},
|
||||
|
||||
// Embeddings should not be quantized
|
||||
{"embedding weight", "embed_tokens.weight", "", false},
|
||||
{"embedding in name", "token_embedding.weight", "", false},
|
||||
|
||||
// Norms should not be quantized
|
||||
{"layer norm", "layer_norm.weight", "", false},
|
||||
{"rms norm", "rms_norm.weight", "", false},
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
|
||||
// Linear weights should be quantized
|
||||
{"linear weight", "q_proj.weight", "", true},
|
||||
{"attention weight", "self_attn.weight", "", true},
|
||||
{"mlp weight", "mlp.gate_proj.weight", "", true},
|
||||
|
||||
// Transformer component weights should be quantized
|
||||
{"transformer weight", "layers.0.weight", "transformer", true},
|
||||
{"text_encoder weight", "encoder.weight", "text_encoder", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantize(tt.tensor, tt.component)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantize(%q, %q) = %v, want %v", tt.tensor, tt.component, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldQuantizeTensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
want bool
|
||||
}{
|
||||
// 2D tensors with sufficient size should be quantized
|
||||
{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
|
||||
{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
|
||||
|
||||
// Small tensors should not be quantized (< 1024 elements)
|
||||
{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
|
||||
{"small 2D weight", "small.weight", []int32{31, 31}, false},
|
||||
|
||||
// 1D tensors should not be quantized
|
||||
{"1D tensor", "layer_norm.weight", []int32{4096}, false},
|
||||
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
|
||||
|
||||
// Norms should not be quantized regardless of shape
|
||||
{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias 2D", "proj.bias", []int32{4096, 1}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ShouldQuantizeTensor(tt.tensor, tt.shape)
|
||||
if got != tt.want {
|
||||
t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create config.json
|
||||
configJSON := `{"model_type": "test", "architectures": ["TestModel"]}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(configJSON), 0644); err != nil {
|
||||
t.Fatalf("failed to write config.json: %v", err)
|
||||
}
|
||||
|
||||
// Create a minimal safetensors file
|
||||
createMinimalSafetensors(t, filepath.Join(dir, "model.safetensors"))
|
||||
|
||||
var quantizeRequested []bool
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
// Run with quantize enabled
|
||||
err := CreateSafetensorsModel("test-model", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSafetensorsModel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify quantize was passed to callback (will be false for small test tensor)
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
|
||||
// createMinimalImageGenModel creates a minimal diffusers-style model directory
|
||||
func createMinimalImageGenModel(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
|
||||
// Create model_index.json
|
||||
modelIndex := `{"_class_name": "FluxPipeline", "_diffusers_version": "0.30.0"}`
|
||||
if err := os.WriteFile(filepath.Join(dir, "model_index.json"), []byte(modelIndex), 0644); err != nil {
|
||||
t.Fatalf("failed to write model_index.json: %v", err)
|
||||
}
|
||||
|
||||
// Create transformer directory with a safetensors file
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
// Create transformer config
|
||||
transformerConfig := `{"hidden_size": 3072}`
|
||||
if err := os.WriteFile(filepath.Join(transformerDir, "config.json"), []byte(transformerConfig), 0644); err != nil {
|
||||
t.Fatalf("failed to write transformer config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var manifestWritten bool
|
||||
var manifestModelName string
|
||||
var statusMessages []string
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name, Digest: "sha256:tensor"}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
manifestWritten = true
|
||||
manifestModelName = modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {
|
||||
statusMessages = append(statusMessages, status)
|
||||
}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if !manifestWritten {
|
||||
t.Error("manifest was not written")
|
||||
}
|
||||
|
||||
if manifestModelName != "test-imagegen" {
|
||||
t.Errorf("manifest model name = %q, want %q", manifestModelName, "test-imagegen")
|
||||
}
|
||||
|
||||
if len(statusMessages) == 0 {
|
||||
t.Error("no status messages received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_NoModelIndex(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create only transformer without model_index.json
|
||||
transformerDir := filepath.Join(dir, "transformer")
|
||||
if err := os.MkdirAll(transformerDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create transformer dir: %v", err)
|
||||
}
|
||||
createMinimalSafetensors(t, filepath.Join(transformerDir, "model.safetensors"))
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name}, nil
|
||||
}
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing model_index.json, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateImageGenModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createMinimalImageGenModel(t, dir)
|
||||
|
||||
var quantizeRequested []bool
|
||||
|
||||
createLayer := func(r io.Reader, mediaType, name string) (LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
return LayerInfo{Name: name, Digest: "sha256:test"}, nil
|
||||
}
|
||||
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error) {
|
||||
io.ReadAll(r)
|
||||
quantizeRequested = append(quantizeRequested, quantize)
|
||||
return []LayerInfo{{Name: name}}, nil
|
||||
}
|
||||
|
||||
writeManifest := func(modelName string, config LayerInfo, layers []LayerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
progressFn := func(status string) {}
|
||||
|
||||
err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateImageGenModel failed: %v", err)
|
||||
}
|
||||
|
||||
if len(quantizeRequested) == 0 {
|
||||
t.Error("no tensors processed")
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package imagegen
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -12,43 +12,17 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// IsTensorModelDir checks if the directory contains a tensor model
|
||||
// by looking for model_index.json, which is the standard diffusers pipeline config.
|
||||
func IsTensorModelDir(dir string) bool {
|
||||
_, err := os.Stat(filepath.Join(dir, "model_index.json"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LayerInfo holds metadata for a created layer.
|
||||
type LayerInfo struct {
|
||||
Digest string
|
||||
Size int64
|
||||
MediaType string
|
||||
Name string // Path-style name: "component/tensor" or "path/to/config.json"
|
||||
}
|
||||
|
||||
// LayerCreator is called to create a blob layer.
|
||||
// name is the path-style name (e.g., "tokenizer/tokenizer.json")
|
||||
type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
|
||||
|
||||
// TensorLayerCreator creates a tensor blob layer with metadata.
|
||||
// name is the path-style name including component (e.g., "text_encoder/model.embed_tokens.weight")
|
||||
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
|
||||
|
||||
// ManifestWriter writes the manifest file.
|
||||
type ManifestWriter func(modelName string, config LayerInfo, layers []LayerInfo) error
|
||||
|
||||
// CreateModel imports an image generation model from a directory.
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is "fp8", linear weights in transformer/text_encoder are quantized to mxfp8 format.
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
|
||||
for _, component := range components {
|
||||
componentDir := filepath.Join(modelDir, component)
|
||||
@@ -126,13 +100,10 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
|
||||
"text_encoder/generation_config.json",
|
||||
"transformer/config.json",
|
||||
"vae/config.json",
|
||||
"vision_language_encoder/config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"tokenizer/tokenizer.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"processor/tokenizer.json", // GLM-Image main tokenizer
|
||||
"processor/tokenizer_config.json", // GLM-Image tokenizer config
|
||||
}
|
||||
|
||||
for _, cfgPath := range configFiles {
|
||||
@@ -1,190 +0,0 @@
|
||||
// Package client provides client-side model creation for tensor-based models.
|
||||
//
|
||||
// This package is in x/ because the tensor model storage format is under development.
|
||||
// It also exists to break an import cycle: server imports x/imagegen, so x/imagegen
|
||||
// cannot import server. This sub-package can import server because server doesn't
|
||||
// import it.
|
||||
//
|
||||
// TODO (jmorganca): This is temporary. When tensor models are promoted to production:
|
||||
// 1. Add proper API endpoints for tensor model creation
|
||||
// 2. Move tensor extraction to server-side
|
||||
// 3. Remove this package
|
||||
// 4. Follow the same client→server pattern as regular model creation
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// MinOllamaVersion is the minimum Ollama version required for image generation models.
|
||||
const MinOllamaVersion = "0.14.0"
|
||||
|
||||
// CreateModel imports a tensor-based model from a local directory.
|
||||
// This creates blobs and manifest directly on disk, bypassing the HTTP API.
|
||||
// If quantize is "fp8", weights will be quantized to mxfp8 format during import.
|
||||
//
|
||||
// TODO (jmorganca): Replace with API-based creation when promoted to production.
|
||||
func CreateModel(modelName, modelDir, quantize string, p *progress.Progress) error {
|
||||
if !imagegen.IsTensorModelDir(modelDir) {
|
||||
return fmt.Errorf("%s is not an image generation model directory (model_index.json not found)", modelDir)
|
||||
}
|
||||
|
||||
status := "importing image generation model"
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
|
||||
// Create layer callback for config files
|
||||
createLayer := func(r io.Reader, mediaType, name string) (imagegen.LayerInfo, error) {
|
||||
layer, err := server.NewLayer(r, mediaType)
|
||||
if err != nil {
|
||||
return imagegen.LayerInfo{}, err
|
||||
}
|
||||
layer.Name = name
|
||||
|
||||
return imagegen.LayerInfo{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create tensor layer callback for individual tensors
|
||||
// name is path-style: "component/tensor_name"
|
||||
// When quantize is true, returns multiple layers (weight + scales)
|
||||
createTensorLayer := func(r io.Reader, name, dtype string, shape []int32, doQuantize bool) ([]imagegen.LayerInfo, error) {
|
||||
if doQuantize {
|
||||
// Check if quantization is supported
|
||||
if !QuantizeSupported() {
|
||||
return nil, fmt.Errorf("quantization requires MLX support")
|
||||
}
|
||||
|
||||
// Quantize the tensor (affine mode returns weight, scales, qbiases)
|
||||
qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
|
||||
}
|
||||
|
||||
// Create layer for quantized weight
|
||||
weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create layer for scales (use _scale suffix convention)
|
||||
scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers := []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: weightLayer.Digest,
|
||||
Size: weightLayer.Size,
|
||||
MediaType: weightLayer.MediaType,
|
||||
Name: name, // Keep original name for weight
|
||||
},
|
||||
{
|
||||
Digest: scalesLayer.Digest,
|
||||
Size: scalesLayer.Size,
|
||||
MediaType: scalesLayer.MediaType,
|
||||
Name: name + "_scale", // Add _scale suffix
|
||||
},
|
||||
}
|
||||
|
||||
// Add qbiases layer if present (affine mode)
|
||||
if qbiasData != nil {
|
||||
qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
layers = append(layers, imagegen.LayerInfo{
|
||||
Digest: qbiasLayer.Digest,
|
||||
Size: qbiasLayer.Size,
|
||||
MediaType: qbiasLayer.MediaType,
|
||||
Name: name + "_qbias", // Add _qbias suffix
|
||||
})
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// Non-quantized path: just create a single layer
|
||||
layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []imagegen.LayerInfo{
|
||||
{
|
||||
Digest: layer.Digest,
|
||||
Size: layer.Size,
|
||||
MediaType: layer.MediaType,
|
||||
Name: name,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create manifest writer callback
|
||||
writeManifest := func(modelName string, config imagegen.LayerInfo, layers []imagegen.LayerInfo) error {
|
||||
name := model.ParseName(modelName)
|
||||
if !name.IsValid() {
|
||||
return fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
// Create a proper config blob with version requirement
|
||||
configData := model.ConfigV2{
|
||||
ModelFormat: "safetensors",
|
||||
Capabilities: []string{"image"},
|
||||
Requires: MinOllamaVersion,
|
||||
}
|
||||
configJSON, err := json.Marshal(configData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create config layer blob
|
||||
configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config layer: %w", err)
|
||||
}
|
||||
|
||||
// Convert LayerInfo to server.Layer (include the original model_index.json in layers)
|
||||
serverLayers := make([]server.Layer, len(layers))
|
||||
for i, l := range layers {
|
||||
serverLayers[i] = server.Layer{
|
||||
MediaType: l.MediaType,
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
Name: l.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return server.WriteManifest(name, configLayer, serverLayers)
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progressFn := func(msg string) {
|
||||
spinner.Stop()
|
||||
status = msg
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add("imagegen", spinner)
|
||||
}
|
||||
|
||||
err := imagegen.CreateModel(modelName, modelDir, quantize, createLayer, createTensorLayer, writeManifest, progressFn)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Created image generation model '%s'\n", modelName)
|
||||
return nil
|
||||
}
|
||||
@@ -11,11 +11,9 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
@@ -63,7 +61,6 @@ func main() {
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
|
||||
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
|
||||
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
|
||||
var inputImages stringSlice
|
||||
@@ -120,33 +117,6 @@ func main() {
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *glmImageFlag:
|
||||
m := &glm_image.Model{}
|
||||
// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
|
||||
var loadErr error
|
||||
if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
|
||||
loadErr = m.LoadFromPath(*modelPath)
|
||||
} else {
|
||||
loadErr = m.Load(*modelPath)
|
||||
}
|
||||
if loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
MaxVisualTokens: int32(*maxTokens),
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *qwenImage:
|
||||
m, loadErr := qwen_image.LoadPersistent(*modelPath)
|
||||
if loadErr != nil {
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Image generation models (experimental)
|
||||
|
||||
Experimental image generation models are available for **macOS** in Ollama:
|
||||
|
||||
## Available models
|
||||
|
||||
- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)
|
||||
|
||||
```
|
||||
ollama run x/z-image-turbo
|
||||
```
|
||||
|
||||
> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models
|
||||
|
||||
More models coming soon:
|
||||
|
||||
1. Qwen-Image-2512
|
||||
2. Qwen-Image-Edit-2511
|
||||
3. GLM-Image
|
||||
@@ -27,7 +27,6 @@ var modelVRAMEstimates = map[string]uint64{
|
||||
"ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
|
||||
"FluxPipeline": 21 * GB, // ~21GB for Flux (same architecture)
|
||||
"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
|
||||
"GlmImagePipeline": 80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
|
||||
}
|
||||
|
||||
// CheckPlatformSupport validates that image generation is supported on the current platform.
|
||||
|
||||
@@ -1,693 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
|
||||
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
|
||||
// Special tokens: 0=pad, 1=eos, 2=unk
|
||||
type ByT5Tokenizer struct {
|
||||
PadTokenID int32
|
||||
EOSTokenID int32
|
||||
UNKTokenID int32
|
||||
}
|
||||
|
||||
// NewByT5Tokenizer creates a new ByT5 tokenizer
|
||||
func NewByT5Tokenizer() *ByT5Tokenizer {
|
||||
return &ByT5Tokenizer{
|
||||
PadTokenID: 0,
|
||||
EOSTokenID: 1,
|
||||
UNKTokenID: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode converts a string to token IDs
|
||||
func (t *ByT5Tokenizer) Encode(text string) []int32 {
|
||||
bytes := []byte(text)
|
||||
tokens := make([]int32, len(bytes))
|
||||
for i, b := range bytes {
|
||||
// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
|
||||
// (tokens 0, 1, 2 are PAD, EOS, UNK)
|
||||
tokens[i] = int32(b) + 3
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
|
||||
bytes := make([]byte, 0, len(tokens))
|
||||
for _, tok := range tokens {
|
||||
if tok >= 3 && tok < 259 {
|
||||
bytes = append(bytes, byte(tok-3))
|
||||
}
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// GenerateConfig holds all options for image generation.
|
||||
type GenerateConfig struct {
|
||||
Prompt string
|
||||
NegativePrompt string // For CFG (optional, not typically used with GLM-Image)
|
||||
GuidanceScale float32 // Guidance scale (default: 1.5)
|
||||
Width int32 // Image width (default: 1024, must be divisible by 32)
|
||||
Height int32 // Image height (default: 1024, must be divisible by 32)
|
||||
Steps int // Diffusion denoising steps (default: 50)
|
||||
Seed int64 // Random seed
|
||||
Progress ProgressFunc // Optional progress callback
|
||||
|
||||
// AR generation options
|
||||
MaxVisualTokens int32 // Max visual tokens to generate (default: 256)
|
||||
Temperature float32 // AR sampling temperature (default: 0.9)
|
||||
TopP float32 // Nucleus sampling (default: 0.75)
|
||||
}
|
||||
|
||||
// ProgressFunc is called during generation with stage and step progress.
|
||||
type ProgressFunc func(stage string, step, totalSteps int)
|
||||
|
||||
// Model represents a GLM-Image hybrid model.
|
||||
type Model struct {
|
||||
ModelName string
|
||||
Tokenizer *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
|
||||
GLMTokenizer *GLMTokenizer // For AR model (visual token generation)
|
||||
TextEncoder *T5TextEncoder
|
||||
VisionLanguageEncoder *VisionLanguageEncoder
|
||||
Transformer *DiffusionTransformer
|
||||
VAEDecoder *VAEDecoder
|
||||
}
|
||||
|
||||
// Load loads the GLM-Image model from ollama blob storage.
|
||||
func (m *Model) Load(modelName string) error {
|
||||
fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelName
|
||||
|
||||
// Load manifest
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load manifest: %w", err)
|
||||
}
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
// Used for T5 text encoder (glyph embeddings)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizer(manifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder (~830MB)
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder (~19GB, 9B params)
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer (~13GB, 7B params)
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.Load(manifest); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder (~775MB)
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.Load(manifest); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the model from a directory path (not ollama manifest)
|
||||
func (m *Model) LoadFromPath(modelPath string) error {
|
||||
fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
|
||||
start := time.Now()
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
mlx.EnableCompile()
|
||||
}
|
||||
|
||||
m.ModelName = modelPath
|
||||
|
||||
// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
|
||||
fmt.Print(" Creating ByT5 tokenizer... ")
|
||||
m.Tokenizer = NewByT5Tokenizer()
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load GLM tokenizer for AR model (visual token generation)
|
||||
fmt.Print(" Loading GLM tokenizer... ")
|
||||
glmTok, err := NewGLMTokenizerFromPath(modelPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("glm tokenizer: %w", err)
|
||||
}
|
||||
m.GLMTokenizer = glmTok
|
||||
fmt.Println("✓")
|
||||
|
||||
// Load T5 text encoder
|
||||
m.TextEncoder = &T5TextEncoder{}
|
||||
if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
|
||||
return fmt.Errorf("text encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.TextEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load vision-language encoder
|
||||
m.VisionLanguageEncoder = &VisionLanguageEncoder{}
|
||||
if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
|
||||
return fmt.Errorf("vision language encoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load diffusion transformer
|
||||
m.Transformer = &DiffusionTransformer{}
|
||||
if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
|
||||
return fmt.Errorf("transformer: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.Transformer)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
// Load VAE decoder
|
||||
m.VAEDecoder = &VAEDecoder{}
|
||||
if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
|
||||
return fmt.Errorf("VAE decoder: %w", err)
|
||||
}
|
||||
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
|
||||
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
|
||||
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
|
||||
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
|
||||
|
||||
mem := mlx.MetalGetActiveMemory()
|
||||
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate creates an image from a prompt.
|
||||
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithProgress creates an image with progress callback.
|
||||
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateFromConfig generates an image using the unified config struct.
|
||||
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
start := time.Now()
|
||||
result, err := m.generate(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateImage implements model.ImageModel interface.
|
||||
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
|
||||
return m.Generate(prompt, width, height, steps, seed)
|
||||
}
|
||||
|
||||
// generate is the internal generation pipeline.
|
||||
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
|
||||
// Apply defaults
|
||||
if cfg.Width <= 0 {
|
||||
cfg.Width = 1024
|
||||
}
|
||||
if cfg.Height <= 0 {
|
||||
cfg.Height = 1024
|
||||
}
|
||||
if cfg.Steps <= 0 {
|
||||
cfg.Steps = 50
|
||||
}
|
||||
if cfg.GuidanceScale <= 0 {
|
||||
cfg.GuidanceScale = 1.5
|
||||
}
|
||||
// Calculate MaxVisualTokens based on image dimensions
|
||||
// GLM-Image generates TWO grids of visual tokens:
|
||||
// 1. First: prev (small) grid - prevTokenH × prevTokenW tokens
|
||||
// 2. Then: target (large) grid - tokenH × tokenW tokens
|
||||
// After generation, we extract only the TARGET grid tokens for diffusion.
|
||||
factor := int32(32)
|
||||
tokenH := cfg.Height / factor
|
||||
tokenW := cfg.Width / factor
|
||||
targetGridTokens := tokenH * tokenW
|
||||
|
||||
// Compute prev grid dimensions using diffusers formula:
|
||||
// ratio = token_h / token_w
|
||||
// prev_token_h = int(sqrt(ratio) * 16)
|
||||
// prev_token_w = int(sqrt(1/ratio) * 16)
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1/ratio) * 16)
|
||||
prevGridTokens := prevTokenH * prevTokenW
|
||||
|
||||
// Total tokens to generate = prev grid + target grid
|
||||
// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
|
||||
cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
|
||||
if cfg.Temperature <= 0 {
|
||||
cfg.Temperature = 0.9
|
||||
}
|
||||
if cfg.TopP <= 0 {
|
||||
cfg.TopP = 0.75
|
||||
}
|
||||
|
||||
// Ensure dimensions are divisible by 32
|
||||
cfg.Width = (cfg.Width / 32) * 32
|
||||
cfg.Height = (cfg.Height / 32) * 32
|
||||
|
||||
tcfg := m.Transformer.Config
|
||||
latentH := cfg.Height / 8
|
||||
latentW := cfg.Width / 8
|
||||
|
||||
// Progress callback helper
|
||||
progress := func(stage string, step, total int) {
|
||||
if cfg.Progress != nil {
|
||||
cfg.Progress(stage, step, total)
|
||||
}
|
||||
}
|
||||
|
||||
// === PHASE 1: T5 Text Encoding ===
|
||||
fmt.Println("[T5] Encoding glyph text...")
|
||||
progress("text_encoding", 0, 1)
|
||||
textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
|
||||
mlx.Keep(textEmbed)
|
||||
mlx.Eval(textEmbed)
|
||||
fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
|
||||
progress("text_encoding", 1, 1)
|
||||
|
||||
// === PHASE 2: AR Visual Token Generation ===
|
||||
fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
|
||||
progress("ar_generation", 0, int(cfg.MaxVisualTokens))
|
||||
visualTokens := m.VisionLanguageEncoder.Generate(
|
||||
cfg.Prompt,
|
||||
m.GLMTokenizer,
|
||||
cfg.MaxVisualTokens,
|
||||
cfg.Temperature,
|
||||
cfg.TopP,
|
||||
cfg.Seed,
|
||||
cfg.Height,
|
||||
cfg.Width,
|
||||
func(step int) {
|
||||
if step%100 == 0 || step < 10 {
|
||||
fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
|
||||
}
|
||||
progress("ar_generation", step, int(cfg.MaxVisualTokens))
|
||||
},
|
||||
)
|
||||
mlx.Keep(visualTokens)
|
||||
mlx.Eval(visualTokens)
|
||||
fmt.Printf("[AR] Done generating visual tokens\n")
|
||||
progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))
|
||||
|
||||
vtShape := visualTokens.Shape()
|
||||
totalGenerated := vtShape[1]
|
||||
fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)
|
||||
|
||||
// Extract only the TARGET grid tokens (skip the prev grid tokens)
|
||||
// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
|
||||
// large_image_start_offset = prev_grid_size
|
||||
var targetGridVisualTokens *mlx.Array
|
||||
if totalGenerated >= prevGridTokens+targetGridTokens {
|
||||
// Full generation completed - extract target grid
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, prevGridTokens + targetGridTokens})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
} else if totalGenerated > prevGridTokens {
|
||||
// Partial target grid - take what we have
|
||||
actualTargetTokens := totalGenerated - prevGridTokens
|
||||
targetGridVisualTokens = mlx.Slice(visualTokens,
|
||||
[]int32{0, prevGridTokens},
|
||||
[]int32{1, totalGenerated})
|
||||
mlx.Keep(targetGridVisualTokens)
|
||||
mlx.Eval(targetGridVisualTokens)
|
||||
fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
|
||||
actualTargetTokens, targetGridTokens)
|
||||
} else {
|
||||
// Not enough tokens - EOS came too early
|
||||
return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
|
||||
totalGenerated, prevGridTokens)
|
||||
}
|
||||
|
||||
// === PHASE 3: Diffusion Decoding ===
|
||||
// Setup scheduler with dynamic shift based on image size
|
||||
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
|
||||
imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
|
||||
scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)
|
||||
|
||||
// Initialize noise latents [B, C, H, W]
|
||||
latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
|
||||
mlx.Eval(latents)
|
||||
|
||||
// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
|
||||
// target_grid tokens -> 2x upsample -> patch_count
|
||||
// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
|
||||
visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)
|
||||
|
||||
// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
|
||||
priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
|
||||
mlx.Keep(priorEmbed)
|
||||
mlx.Eval(priorEmbed)
|
||||
|
||||
// Prepare text conditioning (project T5 embeddings)
|
||||
textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
|
||||
mlx.Keep(textCond)
|
||||
mlx.Eval(textCond)
|
||||
|
||||
// === CFG Setup ===
|
||||
// For classifier-free guidance, we need unconditional (negative) text embeddings
|
||||
// GLM-Image uses empty string "" for negative prompt
|
||||
doCFG := cfg.GuidanceScale > 1.0
|
||||
var negativeTextCond *mlx.Array
|
||||
if doCFG {
|
||||
// Encode empty string for negative prompt
|
||||
negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
|
||||
mlx.Keep(negativeTextEmbed)
|
||||
mlx.Eval(negativeTextEmbed)
|
||||
negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
|
||||
mlx.Keep(negativeTextCond)
|
||||
mlx.Eval(negativeTextCond)
|
||||
negativeTextEmbed.Free()
|
||||
}
|
||||
|
||||
// Prepare conditioning inputs
|
||||
targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
|
||||
cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
|
||||
targetSize = mlx.ToBFloat16(targetSize)
|
||||
cropCoords = mlx.ToBFloat16(cropCoords)
|
||||
mlx.Keep(targetSize)
|
||||
mlx.Keep(cropCoords)
|
||||
mlx.Eval(targetSize, cropCoords)
|
||||
|
||||
pH := latentH / tcfg.PatchSize
|
||||
pW := latentW / tcfg.PatchSize
|
||||
|
||||
// Denoising loop
|
||||
fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
|
||||
progress("diffusion", 0, cfg.Steps)
|
||||
for i := 0; i < cfg.Steps; i++ {
|
||||
fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
|
||||
// Check for cancellation
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
latents.Free()
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Get timestep value for the transformer
|
||||
// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
|
||||
// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
|
||||
timestepVal := scheduler.Timesteps[i] - 1
|
||||
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))
|
||||
|
||||
// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
|
||||
patches := PatchifyLatents(latents, tcfg.PatchSize)
|
||||
|
||||
// Transformer forward with MMDiT architecture
|
||||
// Conditional pass (with text + prior embeddings)
|
||||
outputCond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed,
|
||||
textCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
false, // priorTokenDrop = false for conditional
|
||||
)
|
||||
|
||||
// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
|
||||
noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
var noisePred *mlx.Array
|
||||
if doCFG {
|
||||
// Unconditional pass (empty text, dropped prior embeddings)
|
||||
outputUncond := m.Transformer.ForwardWithPriorDrop(
|
||||
patches,
|
||||
priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
|
||||
negativeTextCond,
|
||||
timestep,
|
||||
targetSize,
|
||||
cropCoords,
|
||||
pH,
|
||||
pW,
|
||||
true, // priorTokenDrop = true for unconditional
|
||||
)
|
||||
noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)
|
||||
|
||||
// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
|
||||
diff := mlx.Sub(noisePredCond, noisePredUncond)
|
||||
scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
|
||||
noisePred = mlx.Add(noisePredUncond, scaled)
|
||||
} else {
|
||||
noisePred = noisePredCond
|
||||
}
|
||||
|
||||
// Scheduler step
|
||||
oldLatents := latents
|
||||
latents = scheduler.Step(noisePred, latents, i)
|
||||
mlx.Eval(latents)
|
||||
oldLatents.Free()
|
||||
|
||||
progress("diffusion", i+1, cfg.Steps)
|
||||
}
|
||||
|
||||
// Cleanup intermediate arrays
|
||||
textEmbed.Free()
|
||||
visualTokens.Free()
|
||||
// visualTokensUpsampled points to visualTokens, don't double-free
|
||||
priorEmbed.Free()
|
||||
textCond.Free()
|
||||
if negativeTextCond != nil {
|
||||
negativeTextCond.Free()
|
||||
}
|
||||
targetSize.Free()
|
||||
cropCoords.Free()
|
||||
|
||||
// === PHASE 4: VAE Decode ===
|
||||
progress("vae_decode", 0, 1)
|
||||
decoded := m.VAEDecoder.Decode(latents)
|
||||
mlx.Eval(decoded)
|
||||
latents.Free()
|
||||
progress("vae_decode", 1, 1)
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// upsampleTokens performs nearest-neighbor upsampling of visual tokens
|
||||
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
|
||||
// scale must be 2 or 4
|
||||
//
|
||||
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
|
||||
// missing tokens are padded with 0 (visual token padding value).
|
||||
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
|
||||
// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
|
||||
// Each token at (i, j) becomes scale*scale tokens in the output
|
||||
|
||||
mlx.Eval(tokens)
|
||||
tokenData := tokens.DataInt32()
|
||||
numTokens := int32(len(tokenData))
|
||||
expectedTokens := prevH * prevW
|
||||
|
||||
// Warn if we got fewer tokens than expected (early EOS)
|
||||
if numTokens < expectedTokens {
|
||||
fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
|
||||
numTokens, expectedTokens)
|
||||
}
|
||||
|
||||
targetH := prevH * scale
|
||||
targetW := prevW * scale
|
||||
upsampled := make([]int32, targetH*targetW)
|
||||
|
||||
for i := int32(0); i < prevH; i++ {
|
||||
for j := int32(0); j < prevW; j++ {
|
||||
srcIdx := i*prevW + j
|
||||
|
||||
// Handle early EOS: use 0 (padding) for missing tokens
|
||||
var val int32
|
||||
if srcIdx < numTokens {
|
||||
val = tokenData[srcIdx]
|
||||
} else {
|
||||
val = 0 // Padding token
|
||||
}
|
||||
|
||||
// Place in scale*scale positions
|
||||
dstI := i * scale
|
||||
dstJ := j * scale
|
||||
for di := int32(0); di < scale; di++ {
|
||||
for dj := int32(0); dj < scale; dj++ {
|
||||
upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
|
||||
}
|
||||
|
||||
// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
|
||||
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
|
||||
shape := latents.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
|
||||
x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
|
||||
// Transpose: -> [B, pH, pW, C, p, p]
|
||||
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
|
||||
// Flatten: -> [B, pH*pW, C*p*p]
|
||||
return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
|
||||
}
|
||||
|
||||
// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
|
||||
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
|
||||
shape := patches.Shape()
|
||||
B := shape[0]
|
||||
|
||||
pH := H / patchSize
|
||||
pW := W / patchSize
|
||||
|
||||
// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
|
||||
x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
|
||||
// Transpose: -> [B, C, pH, p, pW, p]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
|
||||
// Reshape: -> [B, C, H, W]
|
||||
return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
|
||||
}
|
||||
|
||||
// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
|
||||
func CalculateShift(imgSeqLen int32) float32 {
|
||||
cfg := DefaultSchedulerConfig()
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
return m*cfg.MaxShift + cfg.BaseShift
|
||||
}
|
||||
|
||||
// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
|
||||
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
|
||||
// This matches diffusers' _upsample_token_ids function
|
||||
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
|
||||
shape := tokens.Shape()
|
||||
B := shape[0]
|
||||
|
||||
// Reshape to [B, 1, H, W] for interpolation
|
||||
tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)
|
||||
|
||||
// Convert to float for interpolation
|
||||
tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)
|
||||
|
||||
// 2x nearest neighbor upsample
|
||||
// [B, 1, H, W] -> [B, 1, H*2, W*2]
|
||||
upsampled := nearestUpsample2x(tokensFloat)
|
||||
|
||||
// Convert back to int and reshape to [B, H*2*W*2]
|
||||
upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
|
||||
return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
|
||||
}
|
||||
|
||||
// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
|
||||
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
C := shape[1]
|
||||
H := shape[2]
|
||||
W := shape[3]
|
||||
|
||||
// Repeat each element 2x2
|
||||
// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
|
||||
x = mlx.Reshape(x, B, C, H, 1, W, 1)
|
||||
|
||||
// Tile to repeat each pixel 2x2
|
||||
x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})
|
||||
|
||||
// Reshape to final size
|
||||
return mlx.Reshape(x, B, C, H*2, W*2)
|
||||
}
|
||||
@@ -1,358 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// GLMTokenizer implements the GLM tokenizer for the AR model
|
||||
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
|
||||
// greedy longest-match tokenization from the vocab without runtime merging.
|
||||
type GLMTokenizer struct {
|
||||
Vocab map[string]int32 // token string -> token ID
|
||||
VocabReverse map[int32]string // token ID -> token string
|
||||
SpecialTokens map[string]int32 // special token strings -> IDs
|
||||
|
||||
// Special token IDs
|
||||
SopTokenID int32 // <sop> = grid_bos_token (167845)
|
||||
EopTokenID int32 // <eop> = grid_eos_token (167846)
|
||||
BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
|
||||
EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
|
||||
PadTokenID int32
|
||||
|
||||
// Sorted vocab keys by length (longest first) for greedy matching
|
||||
sortedTokens []string
|
||||
}
|
||||
|
||||
// tokenizerJSON represents the structure of tokenizer.json
|
||||
type tokenizerJSON struct {
|
||||
Model struct {
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
// NewGLMTokenizer creates a GLM tokenizer from the model manifest
|
||||
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory in manifest
|
||||
data, err := manifest.ReadConfig("processor/tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
|
||||
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
|
||||
// Read tokenizer.json from processor directory
|
||||
tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
|
||||
data, err := os.ReadFile(tokenizerPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
|
||||
}
|
||||
|
||||
tok := &GLMTokenizer{
|
||||
Vocab: make(map[string]int32),
|
||||
VocabReverse: make(map[int32]string),
|
||||
SpecialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Load vocab from model section
|
||||
for token, id := range tj.Model.Vocab {
|
||||
tok.Vocab[token] = id
|
||||
tok.VocabReverse[id] = token
|
||||
}
|
||||
|
||||
// Load added tokens (special tokens including dit_tokens)
|
||||
for _, at := range tj.AddedTokens {
|
||||
tok.Vocab[at.Content] = at.ID
|
||||
tok.VocabReverse[at.ID] = at.Content
|
||||
if at.Special {
|
||||
tok.SpecialTokens[at.Content] = at.ID
|
||||
}
|
||||
}
|
||||
|
||||
// Set special token IDs
|
||||
tok.SopTokenID = 167845 // <sop>
|
||||
tok.EopTokenID = 167846 // <eop>
|
||||
tok.BosTokenID = 16384 // <|dit_token_16384|>
|
||||
tok.EosTokenID = 16385 // <|dit_token_16385|>
|
||||
tok.PadTokenID = 16385 // Same as EOS
|
||||
|
||||
// Build sorted token list for greedy matching (longest first)
|
||||
tok.sortedTokens = make([]string, 0, len(tok.Vocab))
|
||||
for token := range tok.Vocab {
|
||||
tok.sortedTokens = append(tok.sortedTokens, token)
|
||||
}
|
||||
sort.Slice(tok.sortedTokens, func(i, j int) bool {
|
||||
return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
|
||||
})
|
||||
|
||||
fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
// Encode tokenizes a string into token IDs
|
||||
// This uses greedy longest-match tokenization with GPT-2 style space handling
|
||||
func (t *GLMTokenizer) Encode(text string) []int32 {
|
||||
if text == "" {
|
||||
return []int32{}
|
||||
}
|
||||
|
||||
var tokens []int32
|
||||
|
||||
// First, check for and handle special tokens
|
||||
// Replace special tokens with placeholders, encode, then restore
|
||||
specialReplacements := make(map[string]int32)
|
||||
for special, id := range t.SpecialTokens {
|
||||
if strings.Contains(text, special) {
|
||||
specialReplacements[special] = id
|
||||
}
|
||||
}
|
||||
|
||||
// Process text character by character with special token handling
|
||||
i := 0
|
||||
isFirstToken := true
|
||||
|
||||
for i < len(text) {
|
||||
// Check for special tokens first
|
||||
foundSpecial := false
|
||||
for special, id := range specialReplacements {
|
||||
if strings.HasPrefix(text[i:], special) {
|
||||
tokens = append(tokens, id)
|
||||
i += len(special)
|
||||
isFirstToken = false
|
||||
foundSpecial = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular text with GPT-2 style space prefix
|
||||
// "Ġ" (U+0120) represents a space before a token
|
||||
remaining := text[i:]
|
||||
|
||||
// Try to find the longest matching token
|
||||
matched := false
|
||||
for _, token := range t.sortedTokens {
|
||||
// Skip special tokens in regular matching
|
||||
if _, isSpecial := t.SpecialTokens[token]; isSpecial {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this token matches
|
||||
tokenText := token
|
||||
|
||||
// Handle the Ġ prefix (represents space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
// This token expects a leading space
|
||||
if i > 0 || !isFirstToken {
|
||||
// Check if remaining starts with space + token content
|
||||
tokenContent := token[len("Ġ"):]
|
||||
if strings.HasPrefix(remaining, " "+tokenContent) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += 1 + len(tokenContent) // space + content
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Regular token without space prefix
|
||||
if strings.HasPrefix(remaining, tokenText) {
|
||||
tokens = append(tokens, t.Vocab[token])
|
||||
i += len(tokenText)
|
||||
isFirstToken = false
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
// No token found - skip this character (or use UNK)
|
||||
// For now, just skip unknown characters
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// EncodeForGeneration encodes a prompt with grid tokens for image generation
|
||||
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
//
|
||||
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
|
||||
// space prefix), matching the HuggingFace tokenizer behavior.
|
||||
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
|
||||
// Calculate grid dimensions
|
||||
factor := int32(32)
|
||||
height := (targetHeight / factor) * factor
|
||||
width := (targetWidth / factor) * factor
|
||||
tokenH := height / factor
|
||||
tokenW := width / factor
|
||||
|
||||
// Calculate previous grid dimensions
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(sqrt(ratio) * 16)
|
||||
prevTokenW := int32(sqrt(1.0/ratio) * 16)
|
||||
|
||||
// Encode the prompt text
|
||||
promptTokens := t.Encode(prompt)
|
||||
|
||||
// Build the full sequence:
|
||||
// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
|
||||
// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
|
||||
var tokens []int32
|
||||
tokens = append(tokens, promptTokens...)
|
||||
|
||||
// First grid: <sop> H W <eop>
|
||||
// First number has no space prefix, second number has space prefix (Ġ)
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(tokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// Second grid: <sop> prevH prevW <eop>
|
||||
tokens = append(tokens, t.SopTokenID)
|
||||
tokens = append(tokens, t.encodeNumber(prevTokenH)...)
|
||||
tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
|
||||
tokens = append(tokens, t.EopTokenID)
|
||||
|
||||
// BOS token (start of image generation)
|
||||
tokens = append(tokens, t.BosTokenID)
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
|
||||
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
// First try: look up the whole number as a single token
|
||||
if id, ok := t.Vocab[s]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
// Fallback: encode digit by digit
|
||||
var tokens []int32
|
||||
for _, c := range s {
|
||||
if id, ok := t.Vocab[string(c)]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
|
||||
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
|
||||
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
|
||||
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
|
||||
spaceToken := "Ġ" + s
|
||||
if id, ok := t.Vocab[spaceToken]; ok {
|
||||
return []int32{id}
|
||||
}
|
||||
|
||||
// Fallback: bare space Ġ + number tokens
|
||||
var tokens []int32
|
||||
if spaceID, ok := t.Vocab["Ġ"]; ok {
|
||||
tokens = append(tokens, spaceID)
|
||||
}
|
||||
tokens = append(tokens, t.encodeNumber(n)...)
|
||||
return tokens
|
||||
}
|
||||
|
||||
// sqrt is a helper for float64 sqrt
|
||||
func sqrt(x float64) float64 {
|
||||
if x <= 0 {
|
||||
return 0
|
||||
}
|
||||
// Newton's method
|
||||
z := x
|
||||
for i := 0; i < 10; i++ {
|
||||
z = z - (z*z-x)/(2*z)
|
||||
}
|
||||
return z
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to a string
|
||||
func (t *GLMTokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if token, ok := t.VocabReverse[id]; ok {
|
||||
// Handle Ġ prefix (convert back to space)
|
||||
if strings.HasPrefix(token, "Ġ") {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(token[len("Ġ"):])
|
||||
} else {
|
||||
sb.WriteString(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// FlowMatchSchedulerConfig holds scheduler configuration
|
||||
type FlowMatchSchedulerConfig struct {
|
||||
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
|
||||
BaseShift float32 `json:"base_shift"` // 0.25
|
||||
MaxShift float32 `json:"max_shift"` // 0.75
|
||||
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
|
||||
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 4096
|
||||
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
|
||||
TimeShiftType string `json:"time_shift_type"` // "linear"
|
||||
}
|
||||
|
||||
// DefaultSchedulerConfig returns the default config for GLM-Image
|
||||
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
|
||||
return &FlowMatchSchedulerConfig{
|
||||
NumTrainTimesteps: 1000,
|
||||
BaseShift: 0.25,
|
||||
MaxShift: 0.75,
|
||||
BaseImageSeqLen: 256,
|
||||
MaxImageSeqLen: 4096,
|
||||
UseDynamicShifting: true,
|
||||
TimeShiftType: "linear",
|
||||
}
|
||||
}
|
||||
|
||||
// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
|
||||
type FlowMatchScheduler struct {
|
||||
Config *FlowMatchSchedulerConfig
|
||||
Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
|
||||
Sigmas []float32 // Shifted sigmas for Euler step calculation
|
||||
NumSteps int
|
||||
}
|
||||
|
||||
// NewFlowMatchScheduler creates a new scheduler
|
||||
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
|
||||
return &FlowMatchScheduler{Config: cfg}
|
||||
}
|
||||
|
||||
// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
|
||||
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
|
||||
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
|
||||
s.NumSteps = numSteps
|
||||
|
||||
// Calculate shift (mu) based on image sequence length
|
||||
mu := s.calculateShift(imgSeqLen)
|
||||
|
||||
// Create timesteps: linspace from sigma_max_t to sigma_min_t
|
||||
// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
|
||||
// Then apply time shift and append terminal sigma=0
|
||||
s.Timesteps = make([]float32, numSteps)
|
||||
s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma
|
||||
|
||||
numTrainTimesteps := float32(s.Config.NumTrainTimesteps)
|
||||
|
||||
// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
|
||||
for i := 0; i < numSteps; i++ {
|
||||
// linspace from 1000 to ~20 (sigma_min * num_train_timesteps)
|
||||
tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
|
||||
s.Timesteps[i] = tRaw
|
||||
|
||||
// Convert to sigma [0, 1]
|
||||
sigma := tRaw / numTrainTimesteps
|
||||
|
||||
// Apply time shift if enabled
|
||||
if s.Config.UseDynamicShifting && mu > 0 {
|
||||
sigma = s.applyShift(mu, sigma)
|
||||
}
|
||||
|
||||
s.Sigmas[i] = sigma
|
||||
}
|
||||
|
||||
// Append terminal sigma = 0 (the final clean image)
|
||||
s.Sigmas[numSteps] = 0
|
||||
}
|
||||
|
||||
// calculateShift computes dynamic shift based on image sequence length
|
||||
// Uses the sqrt-based formula from diffusers:
|
||||
// m = (image_seq_len / base_seq_len) ** 0.5
|
||||
// mu = m * max_shift + base_shift
|
||||
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
|
||||
cfg := s.Config
|
||||
|
||||
if !cfg.UseDynamicShifting {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
|
||||
m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
|
||||
mu := m*cfg.MaxShift + cfg.BaseShift
|
||||
return mu
|
||||
}
|
||||
|
||||
// applyShift applies time shift transformation
|
||||
// mu: the computed shift value
|
||||
// t: sigma value in [0, 1]
|
||||
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
|
||||
if t <= 0 {
|
||||
return 0
|
||||
}
|
||||
if t >= 1 {
|
||||
return 1
|
||||
}
|
||||
|
||||
// sigma=1.0 for both shift types
|
||||
sigma := float32(1.0)
|
||||
|
||||
if s.Config.TimeShiftType == "linear" {
|
||||
// Linear: mu / (mu + (1/t - 1)^sigma)
|
||||
return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
|
||||
expMu := float32(math.Exp(float64(mu)))
|
||||
return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
|
||||
}
|
||||
|
||||
// Step performs one denoising step
|
||||
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
|
||||
sigma := s.Sigmas[stepIdx]
|
||||
sigmaNext := s.Sigmas[stepIdx+1]
|
||||
|
||||
// Euler step: x_{t-dt} = x_t + dt * v_t
|
||||
dt := sigmaNext - sigma // Negative (going from noise to clean)
|
||||
|
||||
scaledOutput := mlx.MulScalar(modelOutput, dt)
|
||||
return mlx.Add(sample, scaledOutput)
|
||||
}
|
||||
|
||||
// InitNoise creates initial noise
|
||||
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
|
||||
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
|
||||
}
|
||||
|
||||
// AddNoise adds noise to clean samples for a given timestep (for img2img)
|
||||
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
|
||||
// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
|
||||
// Use sigmas (shifted) for the interpolation
|
||||
sigma := s.Sigmas[timestepIdx]
|
||||
oneMinusSigma := 1.0 - sigma
|
||||
|
||||
scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
|
||||
scaledNoise := mlx.MulScalar(noise, sigma)
|
||||
|
||||
return mlx.Add(scaledClean, scaledNoise)
|
||||
}
|
||||
|
||||
// GetTimesteps returns all timesteps
|
||||
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
|
||||
return s.Timesteps
|
||||
}
|
||||
@@ -1,497 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// T5Config holds T5 encoder configuration
|
||||
type T5Config struct {
|
||||
DModel int32 `json:"d_model"` // 1472
|
||||
DFF int32 `json:"d_ff"` // 3584
|
||||
DKV int32 `json:"d_kv"` // 64
|
||||
NumHeads int32 `json:"num_heads"` // 6
|
||||
NumLayers int32 `json:"num_layers"` // 12
|
||||
VocabSize int32 `json:"vocab_size"` // 384 (byte-level)
|
||||
LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
|
||||
IsGatedAct bool `json:"is_gated_act"` // true (gated-gelu)
|
||||
|
||||
// Relative position bias
|
||||
RelativeAttentionNumBuckets int32 `json:"relative_attention_num_buckets"` // 32
|
||||
RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
|
||||
}
|
||||
|
||||
// T5TextEncoder is the T5 encoder for text conditioning
|
||||
type T5TextEncoder struct {
|
||||
Config *T5Config
|
||||
|
||||
// Embedding (shared for ByT5)
|
||||
SharedEmbed *nn.Embedding `weight:"shared"`
|
||||
|
||||
// Encoder layers
|
||||
Layers []*T5Block `weight:"encoder.block"`
|
||||
|
||||
// Final layer norm
|
||||
FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`
|
||||
|
||||
// Relative position bias (from first layer, shared across all)
|
||||
RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
|
||||
}
|
||||
|
||||
// T5Block is a single T5 encoder block
|
||||
type T5Block struct {
|
||||
// Self attention
|
||||
Layer0 *T5LayerSelfAttention `weight:"layer.0"`
|
||||
// FFN
|
||||
Layer1 *T5LayerFF `weight:"layer.1"`
|
||||
}
|
||||
|
||||
// T5LayerSelfAttention is T5's self-attention layer
|
||||
type T5LayerSelfAttention struct {
|
||||
SelfAttention *T5Attention `weight:"SelfAttention"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5Attention implements T5's relative attention
|
||||
type T5Attention struct {
|
||||
Q *mlx.Array `weight:"q.weight"` // No bias in T5
|
||||
K *mlx.Array `weight:"k.weight"`
|
||||
V *mlx.Array `weight:"v.weight"`
|
||||
O *mlx.Array `weight:"o.weight"`
|
||||
|
||||
NHeads int32
|
||||
DKV int32
|
||||
Scale float32
|
||||
}
|
||||
|
||||
// T5LayerFF is T5's feedforward layer with gated-gelu
|
||||
type T5LayerFF struct {
|
||||
DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
|
||||
LayerNorm *T5LayerNorm `weight:"layer_norm"`
|
||||
}
|
||||
|
||||
// T5DenseGatedGelu is T5's gated-gelu FFN
|
||||
type T5DenseGatedGelu struct {
|
||||
Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
|
||||
Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
|
||||
Wo *mlx.Array `weight:"wo.weight"` // down projection
|
||||
}
|
||||
|
||||
// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
|
||||
type T5LayerNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// Load loads the T5 text encoder from manifest
|
||||
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the T5 text encoder from a directory path
|
||||
func (m *T5TextEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading T5 text encoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg T5Config
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*T5Block, cfg.NumLayers)
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(0); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *T5TextEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
m.FinalNorm.Eps = cfg.LayerNormEps
|
||||
for _, block := range m.Layers {
|
||||
attn := block.Layer0.SelfAttention
|
||||
attn.NHeads = cfg.NumHeads
|
||||
attn.DKV = cfg.DKV
|
||||
attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))
|
||||
|
||||
block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
|
||||
block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
|
||||
}
|
||||
}
|
||||
|
||||
// Forward encodes text tokens
|
||||
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Get embeddings
|
||||
h := m.SharedEmbed.Forward(tokens)
|
||||
|
||||
// Compute relative position bias once
|
||||
seqLen := tokens.Shape()[1]
|
||||
posBias := m.computeRelativePositionBias(seqLen)
|
||||
|
||||
// Forward through layers
|
||||
for _, block := range m.Layers {
|
||||
h = block.Forward(h, posBias, cfg.LayerNormEps)
|
||||
}
|
||||
|
||||
// Final norm
|
||||
h = m.FinalNorm.Forward(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// extractGlyphTexts extracts quoted text (glyphs) from the prompt
|
||||
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
|
||||
// Glyph texts are used for text rendering guidance in the generated image
|
||||
func extractGlyphTexts(prompt string) []string {
|
||||
var glyphTexts []string
|
||||
|
||||
// Extract text in single quotes: 'text'
|
||||
re1 := regexp.MustCompile(`'([^']*)'`)
|
||||
for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Unicode curly double quotes: "text"
|
||||
re2 := regexp.MustCompile(`"([^""]*)"`)
|
||||
for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in ASCII double quotes: "text"
|
||||
re3 := regexp.MustCompile(`"([^"]*)"`)
|
||||
for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text in Japanese quotes: 「text」
|
||||
re4 := regexp.MustCompile(`「([^「」]*)」`)
|
||||
for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
|
||||
if len(match) > 1 {
|
||||
glyphTexts = append(glyphTexts, match[1])
|
||||
}
|
||||
}
|
||||
|
||||
return glyphTexts
|
||||
}
|
||||
|
||||
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
|
||||
// This provides text conditioning for the diffusion transformer via the glyph projector
|
||||
//
|
||||
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
|
||||
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
|
||||
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
|
||||
// This matches diffusers' _get_glyph_embeds() behavior.
|
||||
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
|
||||
// Extract glyph texts from prompt (text in quotes)
|
||||
glyphTexts := extractGlyphTexts(prompt)
|
||||
|
||||
// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
|
||||
if len(glyphTexts) == 0 {
|
||||
glyphTexts = []string{""}
|
||||
}
|
||||
|
||||
// Encode each glyph text and collect token sequences
|
||||
// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
|
||||
var allTokenSeqs [][]int32
|
||||
|
||||
for _, glyphText := range glyphTexts {
|
||||
// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
|
||||
tokens := tok.Encode(glyphText)
|
||||
|
||||
// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
|
||||
tokens = append(tokens, tok.EOSTokenID)
|
||||
|
||||
allTokenSeqs = append(allTokenSeqs, tokens)
|
||||
}
|
||||
|
||||
// Process each glyph text through the encoder
|
||||
var allEmbeddings []*mlx.Array
|
||||
for _, tokens := range allTokenSeqs {
|
||||
tokenLen := len(tokens)
|
||||
if tokenLen == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create token array [1, L]
|
||||
tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})
|
||||
|
||||
// Forward through encoder
|
||||
output := m.Forward(tokensArr)
|
||||
mlx.Eval(output)
|
||||
|
||||
allEmbeddings = append(allEmbeddings, output)
|
||||
}
|
||||
|
||||
// Concatenate all glyph embeddings along sequence dimension
|
||||
var output *mlx.Array
|
||||
if len(allEmbeddings) == 0 {
|
||||
// Fallback: return single zero embedding
|
||||
output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
|
||||
} else if len(allEmbeddings) == 1 {
|
||||
output = allEmbeddings[0]
|
||||
} else {
|
||||
output = mlx.Concatenate(allEmbeddings, 1)
|
||||
}
|
||||
mlx.Eval(output)
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// computeRelativePositionBias computes T5's relative position encoding
|
||||
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create relative position matrix
|
||||
// For each (query_pos, key_pos) pair, compute bucketed relative position
|
||||
numBuckets := cfg.RelativeAttentionNumBuckets
|
||||
maxDistance := cfg.RelativeAttentionMaxDistance
|
||||
|
||||
// Create position indices
|
||||
contextPos := make([]int32, seqLen*seqLen)
|
||||
memoryPos := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen; i++ {
|
||||
for j := int32(0); j < seqLen; j++ {
|
||||
contextPos[i*seqLen+j] = i
|
||||
memoryPos[i*seqLen+j] = j
|
||||
}
|
||||
}
|
||||
|
||||
// Compute relative positions and bucket them
|
||||
buckets := make([]int32, seqLen*seqLen)
|
||||
for i := int32(0); i < seqLen*seqLen; i++ {
|
||||
relPos := memoryPos[i] - contextPos[i]
|
||||
buckets[i] = relativePosistionBucket(relPos, numBuckets, maxDistance, false)
|
||||
}
|
||||
|
||||
// Create bucket indices array
|
||||
bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})
|
||||
|
||||
// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
|
||||
// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
|
||||
bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]
|
||||
|
||||
// Transpose to [numHeads, seqLen, seqLen]
|
||||
bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
|
||||
bias = mlx.ExpandDims(bias, 0) // [1, numHeads, seqLen, seqLen]
|
||||
|
||||
return bias
|
||||
}
|
||||
|
||||
// relativePosistionBucket computes the bucket for a relative position
|
||||
func relativePosistionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
|
||||
var bucket int32 = 0
|
||||
var n int32 = -relativePosition
|
||||
|
||||
if bidirectional {
|
||||
numBuckets /= 2
|
||||
if n < 0 {
|
||||
bucket += numBuckets
|
||||
n = -n
|
||||
}
|
||||
} else {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Half buckets are for exact positions, half are for log-spaced
|
||||
maxExact := numBuckets / 2
|
||||
if n < maxExact {
|
||||
bucket += n
|
||||
} else {
|
||||
// Log-spaced buckets
|
||||
logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
|
||||
bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
|
||||
if bucket > numBuckets-1 {
|
||||
bucket = numBuckets - 1
|
||||
}
|
||||
}
|
||||
|
||||
return bucket
|
||||
}
|
||||
|
||||
// Forward for T5Block
|
||||
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Self attention with residual
|
||||
h := b.Layer0.Forward(x, posBias, eps)
|
||||
|
||||
// FFN with residual
|
||||
h = b.Layer1.Forward(h, eps)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Forward for T5LayerSelfAttention
|
||||
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// Attention
|
||||
attnOut := l.SelfAttention.Forward(normed, posBias)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, attnOut)
|
||||
}
|
||||
|
||||
// Forward for T5Attention
|
||||
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
D := shape[2]
|
||||
|
||||
// Q, K, V projections (no bias)
|
||||
// Weights are [out_features, in_features], so we use matmul with transpose
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))
|
||||
|
||||
// Reshape to [B, L, nheads, d_kv]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
|
||||
k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
|
||||
v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)
|
||||
|
||||
// Transpose to [B, nheads, L, d_kv]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Attention scores with relative position bias
|
||||
// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
|
||||
// (no 1/sqrt(d_k) scale factor like in standard transformers)
|
||||
scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
|
||||
scores = mlx.Add(scores, posBias)
|
||||
|
||||
// Softmax
|
||||
attnWeights := mlx.Softmax(scores, -1)
|
||||
|
||||
// Attend to values
|
||||
out := mlx.Matmul(attnWeights, v)
|
||||
|
||||
// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, D]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))
|
||||
|
||||
_ = D // Silence unused warning
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for T5LayerFF
|
||||
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
||||
// Pre-norm
|
||||
normed := l.LayerNorm.Forward(x)
|
||||
|
||||
// FFN
|
||||
ffOut := l.DenseReluDense.Forward(normed)
|
||||
|
||||
// Residual
|
||||
return mlx.Add(x, ffOut)
|
||||
}
|
||||
|
||||
// geluNew implements the GELU activation with tanh approximation (gelu_new)
|
||||
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
|
||||
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
func geluNew(x *mlx.Array) *mlx.Array {
|
||||
sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
|
||||
coeff := float32(0.044715)
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
|
||||
}
|
||||
|
||||
// Forward for T5DenseGatedGelu (gated-gelu activation)
|
||||
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
|
||||
gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
|
||||
gate = geluNew(gate)
|
||||
|
||||
// Up projection
|
||||
up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))
|
||||
|
||||
// Gated output
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
|
||||
}
|
||||
|
||||
// Forward for T5LayerNorm (RMSNorm variant)
|
||||
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
|
||||
variance := mlx.Mean(mlx.Square(x), -1, true)
|
||||
x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
|
||||
return mlx.Mul(x, ln.Weight)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,477 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VAEConfig holds VAE decoder configuration
|
||||
type VAEConfig struct {
|
||||
InChannels int32 `json:"in_channels"` // 3
|
||||
OutChannels int32 `json:"out_channels"` // 3
|
||||
LatentChannels int32 `json:"latent_channels"` // 16
|
||||
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 512, 1024, 1024]
|
||||
LayersPerBlock int32 `json:"layers_per_block"` // 3
|
||||
NormNumGroups int32 `json:"norm_num_groups"` // 32
|
||||
ScalingFactor float32 `json:"scaling_factor"` // 0.18215
|
||||
ShiftFactor *float32 `json:"shift_factor"` // null
|
||||
LatentsMean []float32 `json:"latents_mean"` // [16 values]
|
||||
LatentsStd []float32 `json:"latents_std"` // [16 values]
|
||||
}
|
||||
|
||||
// VAEDecoder is the VAE latent decoder
|
||||
type VAEDecoder struct {
|
||||
Config *VAEConfig
|
||||
|
||||
// Decoder components
|
||||
ConvIn *VAEConv2d `weight:"decoder.conv_in"`
|
||||
MidBlock *VAEMidBlock `weight:"decoder.mid_block"`
|
||||
UpBlocks []*VAEUpBlock `weight:"decoder.up_blocks"`
|
||||
ConvNormOut *GroupNorm `weight:"decoder.conv_norm_out"`
|
||||
ConvOut *VAEConv2d `weight:"decoder.conv_out"`
|
||||
}
|
||||
|
||||
// VAEConv2d is a 2D convolution layer
|
||||
type VAEConv2d struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
Stride int32
|
||||
Padding int32
|
||||
}
|
||||
|
||||
// GroupNorm is group normalization
|
||||
type GroupNorm struct {
|
||||
Weight *mlx.Array `weight:"weight"`
|
||||
Bias *mlx.Array `weight:"bias"`
|
||||
NumGroups int32
|
||||
Eps float32
|
||||
}
|
||||
|
||||
// VAEMidBlock is the middle block of the VAE
|
||||
type VAEMidBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
}
|
||||
|
||||
// VAEUpBlock is an upsampling block
|
||||
type VAEUpBlock struct {
|
||||
Resnets []*VAEResnetBlock `weight:"resnets"`
|
||||
Upsamplers []*VAEUpsampler `weight:"upsamplers"`
|
||||
}
|
||||
|
||||
// VAEResnetBlock is a residual block
|
||||
type VAEResnetBlock struct {
|
||||
Norm1 *GroupNorm `weight:"norm1"`
|
||||
Conv1 *VAEConv2d `weight:"conv1"`
|
||||
Norm2 *GroupNorm `weight:"norm2"`
|
||||
Conv2 *VAEConv2d `weight:"conv2"`
|
||||
ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
|
||||
}
|
||||
|
||||
// VAEUpsampler is an upsampling layer
|
||||
type VAEUpsampler struct {
|
||||
Conv *VAEConv2d `weight:"conv"`
|
||||
}
|
||||
|
||||
// Load loads the VAE decoder from manifest
|
||||
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
|
||||
// And all but the last up_block has an upsampler
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
// All but the last block has upsamplers
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the VAE decoder from a directory path
|
||||
func (m *VAEDecoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading VAE decoder... ")
|
||||
|
||||
// Load config
|
||||
var cfg VAEConfig
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
m.Config = &cfg
|
||||
|
||||
// Initialize structure based on config
|
||||
numBlocks := len(cfg.BlockOutChannels)
|
||||
m.UpBlocks = make([]*VAEUpBlock, numBlocks)
|
||||
|
||||
// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
|
||||
m.MidBlock = &VAEMidBlock{
|
||||
Resnets: make([]*VAEResnetBlock, 2),
|
||||
}
|
||||
|
||||
// Pre-allocate UpBlocks with their resnets and upsamplers
|
||||
for i := 0; i < numBlocks; i++ {
|
||||
numResnets := cfg.LayersPerBlock + 1
|
||||
m.UpBlocks[i] = &VAEUpBlock{
|
||||
Resnets: make([]*VAEResnetBlock, numResnets),
|
||||
}
|
||||
if i < numBlocks-1 {
|
||||
m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Load weights from safetensors files
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
// Initialize GroupNorm parameters
|
||||
m.initGroupNorms()
|
||||
|
||||
fmt.Println("✓")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VAEDecoder) initGroupNorms() {
|
||||
cfg := m.Config
|
||||
numGroups := cfg.NormNumGroups
|
||||
eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)
|
||||
|
||||
if m.ConvNormOut != nil {
|
||||
m.ConvNormOut.NumGroups = numGroups
|
||||
m.ConvNormOut.Eps = eps
|
||||
}
|
||||
|
||||
if m.MidBlock != nil {
|
||||
for _, resnet := range m.MidBlock.Resnets {
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, upBlock := range m.UpBlocks {
|
||||
if upBlock == nil {
|
||||
continue
|
||||
}
|
||||
for _, resnet := range upBlock.Resnets {
|
||||
if resnet == nil {
|
||||
continue
|
||||
}
|
||||
if resnet.Norm1 != nil {
|
||||
resnet.Norm1.NumGroups = numGroups
|
||||
resnet.Norm1.Eps = eps
|
||||
}
|
||||
if resnet.Norm2 != nil {
|
||||
resnet.Norm2.NumGroups = numGroups
|
||||
resnet.Norm2.Eps = eps
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes latents to an image
|
||||
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Apply latent denormalization if mean/std are provided
|
||||
// This matches diffusers GLM-Image: latents = latents * std + mean
|
||||
// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
|
||||
if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
|
||||
latents = m.denormalizeLatents(latents)
|
||||
}
|
||||
|
||||
// Convert from NCHW to NHWC for processing
|
||||
// [B, C, H, W] -> [B, H, W, C]
|
||||
x := mlx.Transpose(latents, 0, 2, 3, 1)
|
||||
|
||||
// Initial convolution
|
||||
x = m.ConvIn.Forward(x)
|
||||
|
||||
// Mid block
|
||||
x = m.MidBlock.Forward(x)
|
||||
|
||||
// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
|
||||
for i := 0; i < len(m.UpBlocks); i++ {
|
||||
if m.UpBlocks[i] != nil {
|
||||
x = m.UpBlocks[i].Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Final normalization and convolution
|
||||
x = m.ConvNormOut.Forward(x)
|
||||
x = mlx.SiLU(x)
|
||||
x = m.ConvOut.Forward(x)
|
||||
|
||||
// Convert back to NCHW
|
||||
// [B, H, W, C] -> [B, C, H, W]
|
||||
x = mlx.Transpose(x, 0, 3, 1, 2)
|
||||
|
||||
// Clamp to valid range and convert to [0, 1]
|
||||
x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
|
||||
x = mlx.AddScalar(x, 1.0)
|
||||
x = mlx.DivScalar(x, 2.0)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// denormalizeLatents applies the latent mean/std denormalization
|
||||
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Create mean and std arrays [1, C, 1, 1] for broadcasting
|
||||
mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
|
||||
std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})
|
||||
|
||||
// Denormalize: latents * std + mean
|
||||
latents = mlx.Mul(latents, std)
|
||||
latents = mlx.Add(latents, mean)
|
||||
|
||||
return latents
|
||||
}
|
||||
|
||||
// Forward for VAEConv2d
|
||||
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C_in] (NHWC)
|
||||
// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
|
||||
// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
|
||||
// So we need to transpose from OIHW to OHWI
|
||||
|
||||
stride := c.Stride
|
||||
if stride == 0 {
|
||||
stride = 1
|
||||
}
|
||||
padding := c.Padding
|
||||
if padding == 0 {
|
||||
// Default to same padding for 3x3 kernels
|
||||
wShape := c.Weight.Shape()
|
||||
if len(wShape) >= 3 && wShape[2] == 3 {
|
||||
padding = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
|
||||
weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)
|
||||
|
||||
out := mlx.Conv2d(x, weight, stride, padding)
|
||||
if c.Bias != nil {
|
||||
// Bias: [C_out] -> [1, 1, 1, C_out]
|
||||
bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
|
||||
out = mlx.Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Forward for GroupNorm
|
||||
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C] (NHWC)
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
numGroups := gn.NumGroups
|
||||
if numGroups == 0 {
|
||||
numGroups = 32
|
||||
}
|
||||
groupSize := C / numGroups
|
||||
|
||||
// Reshape to [B, H, W, groups, groupSize]
|
||||
x = mlx.Reshape(x, B, H, W, numGroups, groupSize)
|
||||
|
||||
// Compute mean and variance per group
|
||||
mean := mlx.Mean(x, 1, true)
|
||||
mean = mlx.Mean(mean, 2, true)
|
||||
mean = mlx.Mean(mean, 4, true)
|
||||
|
||||
xCentered := mlx.Sub(x, mean)
|
||||
variance := mlx.Mean(mlx.Square(xCentered), 1, true)
|
||||
variance = mlx.Mean(variance, 2, true)
|
||||
variance = mlx.Mean(variance, 4, true)
|
||||
|
||||
// Normalize
|
||||
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
|
||||
|
||||
// Reshape back
|
||||
xNorm = mlx.Reshape(xNorm, B, H, W, C)
|
||||
|
||||
// Scale and shift
|
||||
if gn.Weight != nil {
|
||||
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
|
||||
xNorm = mlx.Mul(xNorm, weight)
|
||||
}
|
||||
if gn.Bias != nil {
|
||||
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
|
||||
xNorm = mlx.Add(xNorm, bias)
|
||||
}
|
||||
|
||||
return xNorm
|
||||
}
|
||||
|
||||
// Forward for VAEMidBlock
|
||||
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
for _, resnet := range mb.Resnets {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEUpBlock
|
||||
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
// Apply resnets
|
||||
for _, resnet := range ub.Resnets {
|
||||
if resnet != nil {
|
||||
x = resnet.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply upsamplers
|
||||
for _, upsampler := range ub.Upsamplers {
|
||||
if upsampler != nil {
|
||||
x = upsampler.Forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for VAEResnetBlock
|
||||
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
|
||||
residual := x
|
||||
|
||||
// First norm + activation + conv
|
||||
h := rb.Norm1.Forward(x)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv1.Forward(h)
|
||||
|
||||
// Second norm + activation + conv
|
||||
h = rb.Norm2.Forward(h)
|
||||
h = mlx.SiLU(h)
|
||||
h = rb.Conv2.Forward(h)
|
||||
|
||||
// Shortcut for channel mismatch
|
||||
if rb.ConvShortcut != nil {
|
||||
residual = rb.ConvShortcut.Forward(residual)
|
||||
}
|
||||
|
||||
return mlx.Add(h, residual)
|
||||
}
|
||||
|
||||
// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
|
||||
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
|
||||
// x: [B, H, W, C]
|
||||
// 2x nearest neighbor upsample
|
||||
x = upsample2x(x)
|
||||
|
||||
// Conv
|
||||
if us.Conv != nil {
|
||||
x = us.Conv.Forward(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// upsample2x performs 2x nearest neighbor upsampling.
|
||||
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
|
||||
func upsample2x(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
H := shape[1]
|
||||
W := shape[2]
|
||||
C := shape[3]
|
||||
|
||||
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
|
||||
hIndices := make([]int32, H*2)
|
||||
for i := int32(0); i < H; i++ {
|
||||
hIndices[i*2] = i
|
||||
hIndices[i*2+1] = i
|
||||
}
|
||||
wIndices := make([]int32, W*2)
|
||||
for i := int32(0); i < W; i++ {
|
||||
wIndices[i*2] = i
|
||||
wIndices[i*2+1] = i
|
||||
}
|
||||
|
||||
hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
|
||||
wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})
|
||||
|
||||
// Take along height axis
|
||||
x = mlx.Reshape(x, B*H, W, C)
|
||||
x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
|
||||
x = mlx.Reshape(x, B, H, W*2, C)
|
||||
|
||||
// Take along width axis - transpose to [B, W*2, H, C], take, transpose back
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
|
||||
x = mlx.Reshape(x, B*(W*2), H, C)
|
||||
x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
|
||||
x = mlx.Reshape(x, B, W*2, H*2, C)
|
||||
x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]
|
||||
|
||||
return x
|
||||
}
|
||||
@@ -1,982 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm_image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// VisionLanguageConfig holds GLM-Image AR generator configuration
|
||||
type VisionLanguageConfig struct {
|
||||
// Text model config
|
||||
HiddenSize int32 `json:"hidden_size"` // 4096
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"` // 40
|
||||
IntermediateSize int32 `json:"intermediate_size"` // 13696
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"` // 32
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 2
|
||||
VocabSize int32 `json:"vocab_size"` // 168064
|
||||
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-5
|
||||
|
||||
// RoPE config
|
||||
RopeTheta float32 `json:"rope_theta"` // 10000
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
|
||||
MRoPESection []int32 `json:"mrope_section"` // [8, 12, 12]
|
||||
|
||||
// Visual token config
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"` // 16512
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"` // 16385
|
||||
ImageTokenID int32 `json:"image_token_id"` // 167855
|
||||
|
||||
// Computed
|
||||
HeadDim int32
|
||||
}
|
||||
|
||||
// VisionLanguageEncoder is the 9B AR generator
|
||||
type VisionLanguageEncoder struct {
|
||||
Config *VisionLanguageConfig
|
||||
|
||||
// Embedding
|
||||
EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`
|
||||
|
||||
// Transformer layers
|
||||
Layers []*GLMBlock `weight:"model.language_model.layers"`
|
||||
|
||||
// Final norm
|
||||
FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`
|
||||
|
||||
// LM Head
|
||||
LMHead *mlx.Array `weight:"lm_head.weight"`
|
||||
}
|
||||
|
||||
// GLMBlock is a single transformer block in GLM-4 style
|
||||
type GLMBlock struct {
|
||||
// Pre-attention norm (GLM uses post-LN variant)
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostSelfAttnNorm *nn.RMSNorm `weight:"post_self_attn_layernorm"`
|
||||
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
PostMLPLayerNorm *nn.RMSNorm `weight:"post_mlp_layernorm"`
|
||||
|
||||
// Attention
|
||||
SelfAttn *GLMAttention `weight:"self_attn"`
|
||||
|
||||
// MLP (fused gate_up)
|
||||
MLP *GLMMLP `weight:"mlp"`
|
||||
}
|
||||
|
||||
// GLMAttention implements GQA with partial rotary and MRoPE
|
||||
type GLMAttention struct {
|
||||
QProj *mlx.Array `weight:"q_proj.weight"`
|
||||
KProj *mlx.Array `weight:"k_proj.weight"`
|
||||
VProj *mlx.Array `weight:"v_proj.weight"`
|
||||
OProj *mlx.Array `weight:"o_proj.weight"`
|
||||
|
||||
// QKV have biases in GLM
|
||||
QBias *mlx.Array `weight:"q_proj.bias"`
|
||||
KBias *mlx.Array `weight:"k_proj.bias"`
|
||||
VBias *mlx.Array `weight:"v_proj.bias"`
|
||||
|
||||
// Computed
|
||||
NHeads int32
|
||||
NKVHeads int32
|
||||
HeadDim int32
|
||||
Scale float32
|
||||
PartialRotary float32 // Only rotate this fraction of head_dim
|
||||
RopeTheta float32
|
||||
MRoPESection []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
|
||||
}
|
||||
|
||||
// ARCache holds KV caches for all layers using the shared cache implementation
|
||||
type ARCache struct {
|
||||
Layers []cache.Cache
|
||||
}
|
||||
|
||||
// NewARCache creates a new cache for the given number of layers
|
||||
func NewARCache(numLayers int32) *ARCache {
|
||||
layers := make([]cache.Cache, numLayers)
|
||||
for i := range layers {
|
||||
layers[i] = cache.NewKVCache()
|
||||
}
|
||||
return &ARCache{Layers: layers}
|
||||
}
|
||||
|
||||
// Free releases all cached tensors
|
||||
func (c *ARCache) Free() {
|
||||
for _, layer := range c.Layers {
|
||||
for _, arr := range layer.State() {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GLMMLP implements fused gate_up SwiGLU MLP
|
||||
type GLMMLP struct {
|
||||
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
|
||||
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
|
||||
DownProj *mlx.Array `weight:"down_proj.weight"`
|
||||
}
|
||||
|
||||
// Load loads the vision-language encoder from manifest
|
||||
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromPath loads the vision-language encoder from a directory path
|
||||
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
|
||||
fmt.Print(" Loading vision-language encoder... ")
|
||||
|
||||
// Load config
|
||||
var rawCfg struct {
|
||||
TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||
RopeParameters struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
}
|
||||
|
||||
configPath := filepath.Join(path, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, &rawCfg); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &VisionLanguageConfig{
|
||||
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||
ImageTokenID: rawCfg.ImageTokenID,
|
||||
}
|
||||
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
m.Config = cfg
|
||||
|
||||
// Pre-allocate layers
|
||||
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||
|
||||
// Load weights
|
||||
weights, err := safetensors.LoadModelWeights(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weights: %w", err)
|
||||
}
|
||||
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||
return fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
defer weights.ReleaseAll()
|
||||
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return fmt.Errorf("load module: %w", err)
|
||||
}
|
||||
|
||||
m.initComputedFields()
|
||||
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *VisionLanguageEncoder) initComputedFields() {
|
||||
cfg := m.Config
|
||||
for _, block := range m.Layers {
|
||||
block.SelfAttn.NHeads = cfg.NumAttentionHeads
|
||||
block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
|
||||
block.SelfAttn.HeadDim = cfg.HeadDim
|
||||
block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
|
||||
block.SelfAttn.RopeTheta = cfg.RopeTheta
|
||||
block.SelfAttn.MRoPESection = cfg.MRoPESection
|
||||
|
||||
// Set norm eps
|
||||
block.InputLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
|
||||
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
|
||||
block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
m.FinalNorm.Eps = cfg.RMSNormEps
|
||||
}
|
||||
|
||||
// Generate autoregressively generates visual tokens with KV caching
|
||||
func (m *VisionLanguageEncoder) Generate(
|
||||
prompt string,
|
||||
tok *GLMTokenizer,
|
||||
maxTokens int32,
|
||||
temperature float32,
|
||||
topP float32,
|
||||
seed int64,
|
||||
targetHeight, targetWidth int32,
|
||||
progressFn func(int),
|
||||
) *mlx.Array {
|
||||
cfg := m.Config
|
||||
|
||||
// Encode prompt with grid tokens using GLM tokenizer
|
||||
// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
|
||||
tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)
|
||||
|
||||
// Calculate grid dimensions for MRoPE position IDs
|
||||
factor := int32(32)
|
||||
tokenH := targetHeight / factor
|
||||
tokenW := targetWidth / factor
|
||||
ratio := float64(tokenH) / float64(tokenW)
|
||||
prevTokenH := int32(math.Sqrt(ratio) * 16)
|
||||
prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
|
||||
prevGridSize := prevTokenH * prevTokenW
|
||||
|
||||
// Create KV cache for all layers
|
||||
cache := NewARCache(cfg.NumHiddenLayers)
|
||||
defer cache.Free()
|
||||
|
||||
// ===== PREFILL PHASE =====
|
||||
// Process entire prompt at once, populate cache
|
||||
promptLen := int32(len(tokens))
|
||||
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
mlx.Eval(h)
|
||||
|
||||
// Compute position IDs for prefill (text tokens use same position for all dims)
|
||||
prefillPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
prefillPositions[dim] = make([]int32, promptLen)
|
||||
for i := int32(0); i < promptLen; i++ {
|
||||
prefillPositions[dim][i] = i
|
||||
}
|
||||
}
|
||||
|
||||
// Forward through layers (prefill)
|
||||
for i, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
|
||||
if i > 0 {
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
// Eval h and cache arrays together so cache is materialized
|
||||
evalArgs := []*mlx.Array{h}
|
||||
for _, lc := range cache.Layers {
|
||||
evalArgs = append(evalArgs, lc.State()...)
|
||||
}
|
||||
mlx.Eval(evalArgs...)
|
||||
|
||||
// Final norm and get logits for last position
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
|
||||
h.Free()
|
||||
lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
|
||||
lastH.Free()
|
||||
|
||||
// Sample first token
|
||||
var sampleCounter int64 = 0
|
||||
nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
// AR generation loop with caching
|
||||
// Visual tokens are stored as VQ codebook indices [0, 16383]
|
||||
// The LM head outputs indices [0, 16511] where:
|
||||
// - [0, 16383] are VQ codes
|
||||
// - 16384 is BOS
|
||||
// - 16385 is EOS
|
||||
visualTokens := make([]int32, 0, maxTokens)
|
||||
posOffset := promptLen
|
||||
visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation
|
||||
|
||||
// Preallocate slice for old cache state to reuse
|
||||
oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
|
||||
for i := int32(0); i < maxTokens; i++ {
|
||||
if progressFn != nil {
|
||||
progressFn(int(i))
|
||||
}
|
||||
|
||||
// Check for end token (EOS = 16385)
|
||||
if nextToken == cfg.ImageEndTokenID {
|
||||
break
|
||||
}
|
||||
|
||||
// Skip BOS token (16384), only store actual VQ codes [0, 16383]
|
||||
if nextToken == cfg.ImageStartTokenID {
|
||||
// BOS token - skip storing but continue generation
|
||||
} else if nextToken < cfg.ImageStartTokenID {
|
||||
// This is an actual VQ code [0, 16383] - store it
|
||||
visualTokens = append(visualTokens, nextToken)
|
||||
}
|
||||
// Tokens >= 16386 are other special tokens, skip them
|
||||
|
||||
// ===== DECODE PHASE =====
|
||||
// Save old cache state before forward (to free after eval)
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, lc := range cache.Layers {
|
||||
oldCacheState = append(oldCacheState, lc.State()...)
|
||||
}
|
||||
|
||||
// Only process the new token, use cached K,V
|
||||
tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
|
||||
h := m.EmbedTokens.Forward(tokenArr)
|
||||
tokenArr.Free()
|
||||
|
||||
// Compute MRoPE position IDs for this visual token
|
||||
// Visual tokens are arranged in two grids: prev grid then target grid
|
||||
// Position dimensions: [temporal, height, width]
|
||||
decodePositions := computeVisualTokenPositions(
|
||||
visualTokenIdx, posOffset, promptLen,
|
||||
prevTokenH, prevTokenW, prevGridSize,
|
||||
tokenH, tokenW,
|
||||
)
|
||||
|
||||
// Forward through layers (decode with cache)
|
||||
for j, layer := range m.Layers {
|
||||
oldH := h
|
||||
h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
|
||||
if j > 0 { // Don't free the embedding on first layer
|
||||
oldH.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Eval h and new cache state
|
||||
newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
|
||||
for _, lc := range cache.Layers {
|
||||
newCacheState = append(newCacheState, lc.State()...)
|
||||
}
|
||||
mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)
|
||||
|
||||
// Free old cache state (now that new state is evaluated)
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Final norm
|
||||
preNormH := h
|
||||
h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
|
||||
preNormH.Free()
|
||||
|
||||
// Get logits (h is already [1, 1, hidden_size])
|
||||
h = mlx.Reshape(h, 1, cfg.HiddenSize)
|
||||
logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
|
||||
h.Free()
|
||||
|
||||
// Sample next token
|
||||
nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
|
||||
logits.Free()
|
||||
|
||||
posOffset++
|
||||
visualTokenIdx++
|
||||
|
||||
// Periodically clear cache to release intermediate memory
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
if len(visualTokens) == 0 {
|
||||
// Return at least one token to avoid empty tensor issues
|
||||
visualTokens = append(visualTokens, 0)
|
||||
}
|
||||
|
||||
return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
|
||||
}
|
||||
|
||||
// computeVisualTokenPositions computes MRoPE position IDs for a visual token
|
||||
// Returns [3][1] position IDs for temporal, height, and width dimensions
|
||||
//
|
||||
// MRoPE position encoding for GLM-Image visual tokens:
|
||||
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
|
||||
// - height: decode_pos + row index within grid
|
||||
// - width: decode_pos + column index within grid
|
||||
//
|
||||
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
|
||||
// sufficient positional separation.
|
||||
func computeVisualTokenPositions(
|
||||
visualIdx int32, absPos int32, promptLen int32,
|
||||
prevH, prevW, prevSize int32,
|
||||
targetH, targetW int32,
|
||||
) [][]int32 {
|
||||
positions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
positions[dim] = make([]int32, 1)
|
||||
}
|
||||
|
||||
// First grid (prev grid) starts at decode_pos = promptLen
|
||||
prevGridDecodePos := promptLen
|
||||
|
||||
// Second grid (target grid) starts after first grid
|
||||
// next_pos = prev_decode_pos + max(prevH, prevW)
|
||||
maxPrev := prevH
|
||||
if prevW > maxPrev {
|
||||
maxPrev = prevW
|
||||
}
|
||||
targetGridDecodePos := prevGridDecodePos + maxPrev
|
||||
|
||||
// Compute position IDs based on which grid the token is in
|
||||
if visualIdx < prevSize {
|
||||
// Token is in the prev grid (prev_token_h × prev_token_w)
|
||||
row := visualIdx / prevW
|
||||
col := visualIdx % prevW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = prevGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = prevGridDecodePos + row
|
||||
positions[2][0] = prevGridDecodePos + col
|
||||
} else {
|
||||
// Token is in the target grid (token_h × token_w)
|
||||
targetIdx := visualIdx - prevSize
|
||||
row := targetIdx / targetW
|
||||
col := targetIdx % targetW
|
||||
|
||||
// temporal is CONSTANT for all tokens in this grid
|
||||
positions[0][0] = targetGridDecodePos
|
||||
// height and width are relative to grid's decode_pos
|
||||
positions[1][0] = targetGridDecodePos + row
|
||||
positions[2][0] = targetGridDecodePos + col
|
||||
}
|
||||
|
||||
_ = targetH // Used for documentation clarity
|
||||
_ = absPos // No longer used - kept for API compatibility
|
||||
return positions
|
||||
}
|
||||
|
||||
// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
|
||||
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
|
||||
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
|
||||
// sampleCounter is incremented for each call to ensure different random values
|
||||
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
|
||||
// The LMHead outputs logits for visual tokens only (shape [1, 16512])
|
||||
// Output index directly corresponds to vocab ID [0, 16511]
|
||||
// No offset needed - the visual tokens are at vocab IDs [0, 16511]
|
||||
visualLogits := logits
|
||||
|
||||
// Apply temperature
|
||||
if temperature != 1.0 && temperature > 0 {
|
||||
visualLogits = mlx.DivScalar(visualLogits, temperature)
|
||||
}
|
||||
|
||||
// Apply softmax to get probabilities
|
||||
probs := mlx.Softmax(visualLogits, -1)
|
||||
mlx.Eval(probs)
|
||||
|
||||
// Get the sampled index using top-p sampling
|
||||
// This directly gives us the vocab ID in [0, 16511]
|
||||
// Special tokens: 16384 = BOS, 16385 = EOS
|
||||
// Use seed + counter for reproducible but different random values
|
||||
effectiveSeed := seed + *sampleCounter
|
||||
*sampleCounter++
|
||||
return sampleTopP(probs, topP, effectiveSeed)
|
||||
}
|
||||
|
||||
// sampleTopP implements nucleus (top-p) sampling
|
||||
// probs: [1, vocab_size] probability distribution
|
||||
// topP: cumulative probability threshold (e.g., 0.75)
|
||||
// seed: random seed for reproducible sampling
|
||||
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
|
||||
// Negate probs for descending sort (Argsort only does ascending)
|
||||
negProbs := mlx.MulScalar(probs, -1)
|
||||
sortedIndices := mlx.Argsort(negProbs, -1)
|
||||
sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
|
||||
cumProbs := mlx.Cumsum(sortedProbs, -1)
|
||||
mlx.Eval(sortedIndices, sortedProbs, cumProbs)
|
||||
|
||||
// Find cutoff index where cumulative probability exceeds topP
|
||||
probsData := sortedProbs.Data()
|
||||
cumProbsData := cumProbs.Data()
|
||||
indicesData := sortedIndices.DataInt32()
|
||||
|
||||
// Calculate cutoff and renormalize
|
||||
var cutoffIdx int
|
||||
var totalProb float32
|
||||
for i, cp := range cumProbsData {
|
||||
totalProb += probsData[i]
|
||||
if cp >= topP {
|
||||
cutoffIdx = i + 1 // Include this token
|
||||
break
|
||||
}
|
||||
}
|
||||
if cutoffIdx == 0 {
|
||||
cutoffIdx = len(probsData) // Use all tokens if topP is very high
|
||||
}
|
||||
|
||||
// Sample from the truncated distribution
|
||||
// Renormalize the truncated probabilities
|
||||
truncatedProbs := make([]float32, cutoffIdx)
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
truncatedProbs[i] = probsData[i] / totalProb
|
||||
}
|
||||
|
||||
// Sample using random number with provided seed for reproducibility
|
||||
r := mlx.RandomUniform([]int32{1}, uint64(seed))
|
||||
mlx.Eval(r)
|
||||
randVal := r.Data()[0]
|
||||
|
||||
// Find the sampled token
|
||||
var cumulative float32
|
||||
for i := 0; i < cutoffIdx; i++ {
|
||||
cumulative += truncatedProbs[i]
|
||||
if randVal < cumulative {
|
||||
return indicesData[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the last token in truncated set
|
||||
return indicesData[cutoffIdx-1]
|
||||
}
|
||||
|
||||
// Forward for GLMBlock
|
||||
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
|
||||
return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs block forward with optional KV caching and MRoPE
|
||||
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
|
||||
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
// Pre-attention norm
|
||||
normed := b.InputLayerNorm.Forward(x, eps)
|
||||
|
||||
// Self-attention with RoPE/MRoPE and cache
|
||||
attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)
|
||||
|
||||
// Post-attention norm (GLM-4 style)
|
||||
attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, attnOut)
|
||||
|
||||
// Post-attention layer norm
|
||||
normed = b.PostAttnLayerNorm.Forward(x, eps)
|
||||
|
||||
// MLP
|
||||
mlpOut := b.MLP.Forward(normed)
|
||||
|
||||
// Post-MLP norm
|
||||
mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)
|
||||
|
||||
// Residual connection
|
||||
x = mlx.Add(x, mlpOut)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward for GLMAttention (without cache - used for prefill)
|
||||
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
|
||||
return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
|
||||
}
|
||||
|
||||
// ForwardWithCache performs attention with optional KV caching and MRoPE
|
||||
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
|
||||
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
|
||||
// kvcache is updated in-place if provided
|
||||
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
|
||||
// Q, K, V projections
|
||||
q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
|
||||
k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
|
||||
v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))
|
||||
|
||||
// Add biases
|
||||
if attn.QBias != nil {
|
||||
q = mlx.Add(q, attn.QBias)
|
||||
}
|
||||
if attn.KBias != nil {
|
||||
k = mlx.Add(k, attn.KBias)
|
||||
}
|
||||
if attn.VBias != nil {
|
||||
v = mlx.Add(v, attn.VBias)
|
||||
}
|
||||
|
||||
// Reshape to [B, L, nheads, head_dim]
|
||||
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
|
||||
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
|
||||
|
||||
// Apply partial RoPE or MRoPE
|
||||
rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
|
||||
if len(attn.MRoPESection) == 3 && positionIDs != nil {
|
||||
// Use MRoPE with explicit position IDs
|
||||
q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else if len(attn.MRoPESection) == 3 {
|
||||
// Use MRoPE with sequential positions (same for all dims - text mode)
|
||||
seqPositions := make([][]int32, 3)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
seqPositions[dim] = make([]int32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
seqPositions[dim][i] = i + posOffset
|
||||
}
|
||||
}
|
||||
q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
|
||||
} else {
|
||||
// Fallback to standard RoPE
|
||||
q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
|
||||
}
|
||||
|
||||
// Transpose to [B, nheads, L, head_dim]
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
// Update cache and get full K, V for attention
|
||||
if kvcache != nil {
|
||||
k, v = kvcache.Update(k, v, int(L))
|
||||
}
|
||||
|
||||
// Repeat KV for GQA
|
||||
kExpanded := k
|
||||
vExpanded := v
|
||||
if attn.NKVHeads < attn.NHeads {
|
||||
repeats := attn.NHeads / attn.NKVHeads
|
||||
kExpanded = repeatKV(k, repeats)
|
||||
vExpanded = repeatKV(v, repeats)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention with causal mask
|
||||
out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)
|
||||
|
||||
// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
|
||||
out = mlx.Transpose(out, 0, 2, 1, 3)
|
||||
// Reshape to [B, L, hidden_size]
|
||||
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
|
||||
|
||||
// Output projection
|
||||
out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
|
||||
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
|
||||
}
|
||||
|
||||
// applyPartialRoPEWithOffset applies RoPE with a position offset
|
||||
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply RoPE to rotary part with position offset
|
||||
xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
|
||||
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
|
||||
// mrope_section: [8, 12, 12] - frequency pairs per dimension
|
||||
// For text tokens: all 3 dimensions have the same sequential position
|
||||
// For image tokens: temporal=seq_idx, height=row, width=col
|
||||
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
|
||||
if rotaryDim <= 0 || rotaryDim > D {
|
||||
rotaryDim = D
|
||||
}
|
||||
|
||||
// Split into rotary and pass-through parts
|
||||
xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
|
||||
xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})
|
||||
|
||||
// Apply MRoPE to rotary part
|
||||
xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)
|
||||
|
||||
// Concatenate back
|
||||
return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
|
||||
}
|
||||
|
||||
// applyMRoPE applies multi-dimensional rotary position embedding
|
||||
// x: [B, L, H, D] where D is the rotary dimension
|
||||
// positionIDs: [3][L] - positions for temporal, height, width dimensions
|
||||
// mropeSection: [8, 12, 12] - frequency pairs per dimension
|
||||
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Validate mrope_section sums to half (number of frequency pairs)
|
||||
var totalPairs int32
|
||||
for _, s := range mropeSection {
|
||||
totalPairs += s
|
||||
}
|
||||
if totalPairs != half {
|
||||
// Fallback to standard RoPE if section doesn't match
|
||||
return applyRoPEWithOffset(x, L, 0, theta)
|
||||
}
|
||||
|
||||
// Build angles for each position dimension (matching Python's MRoPE approach)
|
||||
// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
|
||||
// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
|
||||
angleVals := make([]*mlx.Array, 3)
|
||||
|
||||
freqOffset := int32(0)
|
||||
for dim := 0; dim < 3; dim++ {
|
||||
numPairs := mropeSection[dim]
|
||||
if numPairs == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compute inverse frequencies for this section
|
||||
// Each dimension uses DIFFERENT frequency ranges:
|
||||
// - Temporal: frequencies 0 to section[0]-1
|
||||
// - Height: frequencies section[0] to section[0]+section[1]-1
|
||||
// - Width: frequencies section[0]+section[1] to sum(section)-1
|
||||
freqsArr := make([]float32, numPairs)
|
||||
for i := int32(0); i < numPairs; i++ {
|
||||
globalIdx := freqOffset + i
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{numPairs})
|
||||
|
||||
// Position indices for this dimension
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(positionIDs[dim][i])
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, numPairs] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
|
||||
angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
freqOffset += numPairs
|
||||
}
|
||||
|
||||
// Concatenate all sections: [L, half] = [L, 32]
|
||||
allAngles := mlx.Concatenate(angleVals, 1)
|
||||
|
||||
// Duplicate AFTER concatenation: [L, D] = [L, 64]
|
||||
// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
|
||||
allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)
|
||||
|
||||
// Compute cos/sin
|
||||
allCos := mlx.Cos(allAngles)
|
||||
allSin := mlx.Sin(allAngles)
|
||||
|
||||
// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
|
||||
allCos = mlx.Reshape(allCos, 1, L, 1, D)
|
||||
allSin = mlx.Reshape(allSin, 1, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1)
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
|
||||
}
|
||||
|
||||
// applyRoPE applies rotary position embedding
|
||||
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
|
||||
return applyRoPEWithOffset(x, seqLen, 0, theta)
|
||||
}
|
||||
|
||||
// applyRoPEWithOffset applies rotary position embedding with position offset
|
||||
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
|
||||
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B := shape[0]
|
||||
L := shape[1]
|
||||
H := shape[2]
|
||||
D := shape[3]
|
||||
half := D / 2
|
||||
|
||||
// Compute inverse frequencies: 1 / (theta^(2i/d))
|
||||
freqsArr := make([]float32, half)
|
||||
for i := int32(0); i < half; i++ {
|
||||
freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
|
||||
}
|
||||
freqs := mlx.NewArray(freqsArr, []int32{half})
|
||||
|
||||
// Position indices with offset
|
||||
posArr := make([]float32, L)
|
||||
for i := int32(0); i < L; i++ {
|
||||
posArr[i] = float32(i + posOffset)
|
||||
}
|
||||
pos := mlx.NewArray(posArr, []int32{L})
|
||||
|
||||
// Compute angles: [L, half] = outer(pos, freqs)
|
||||
posExpanded := mlx.Reshape(pos, L, 1)
|
||||
freqsExpanded := mlx.Reshape(freqs, 1, half)
|
||||
angles := mlx.Mul(posExpanded, freqsExpanded)
|
||||
|
||||
// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
|
||||
anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)
|
||||
|
||||
// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
|
||||
cosVals := mlx.Cos(anglesDup)
|
||||
sinVals := mlx.Sin(anglesDup)
|
||||
cosVals = mlx.Reshape(cosVals, L, 1, D)
|
||||
sinVals = mlx.Reshape(sinVals, L, 1, D)
|
||||
|
||||
// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
|
||||
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
|
||||
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
|
||||
x2Neg := mlx.MulScalar(x2, -1) // -x_imag
|
||||
xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3) // [-x_imag, x_real]
|
||||
|
||||
// out = x * cos + x_rotated * sin
|
||||
return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
|
||||
}
|
||||
|
||||
// repeatKV repeats key/value heads for GQA
|
||||
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
|
||||
if repeats == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
// x: [B, nkvheads, L, head_dim]
|
||||
x = mlx.ExpandDims(x, 2)
|
||||
// x: [B, nkvheads, 1, L, head_dim]
|
||||
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
|
||||
// x: [B, nkvheads, repeats, L, head_dim]
|
||||
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
|
||||
}
|
||||
|
||||
// Forward for GLMMLP (fused gate_up SwiGLU)
|
||||
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
// gate_up_proj outputs [gate, up] concatenated
|
||||
gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))
|
||||
|
||||
shape := gateUp.Shape()
|
||||
halfDim := shape[len(shape)-1] / 2
|
||||
|
||||
// Split into gate and up
|
||||
gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
|
||||
up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})
|
||||
|
||||
// SwiGLU: silu(gate) * up
|
||||
gate = mlx.SiLU(gate)
|
||||
h := mlx.Mul(gate, up)
|
||||
|
||||
// Down projection
|
||||
return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package imagegen
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
|
||||
// When quantize is true, returns multiple layers (weight + scales + biases).
|
||||
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize bool) ([]LayerInfo, error)
|
||||
|
||||
// ShouldQuantize returns true if a tensor should be quantized.
|
||||
// Quantizes linear weights only, skipping VAE, embeddings, norms, and biases.
|
||||
func ShouldQuantize(name, component string) bool {
|
||||
if component == "vae" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(name, "embed") || strings.Contains(name, "norm") {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(name, ".weight")
|
||||
}
|
||||
@@ -19,15 +19,9 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm_image"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
)
|
||||
|
||||
// ImageModel is the interface for image generation models
|
||||
type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -47,9 +41,8 @@ type Response struct {
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
model ImageModel
|
||||
model *zimage.Model
|
||||
modelName string
|
||||
modelType string // "zimage" or "glm_image"
|
||||
}
|
||||
|
||||
// Execute is the entry point for the image runner subprocess
|
||||
@@ -79,35 +72,15 @@ func Execute(args []string) error {
|
||||
requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
|
||||
}
|
||||
|
||||
// Detect model type and load appropriate model
|
||||
modelType, err := detectModelType(*modelName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to detect model type: %w", err)
|
||||
}
|
||||
|
||||
var model ImageModel
|
||||
switch modelType {
|
||||
case "GlmImagePipeline":
|
||||
slog.Info("loading GLM-Image model")
|
||||
m := &glm_image.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load GLM-Image model: %w", err)
|
||||
}
|
||||
model = m
|
||||
default:
|
||||
// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
|
||||
slog.Info("loading Z-Image model")
|
||||
m := &zimage.Model{}
|
||||
if err := m.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load Z-Image model: %w", err)
|
||||
}
|
||||
model = m
|
||||
// Load model
|
||||
model := &zimage.Model{}
|
||||
if err := model.Load(*modelName); err != nil {
|
||||
return fmt.Errorf("failed to load model: %w", err)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
model: model,
|
||||
modelName: *modelName,
|
||||
modelType: modelType,
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
@@ -171,13 +144,7 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
req.Height = 1024
|
||||
}
|
||||
if req.Steps <= 0 {
|
||||
// Default steps depend on model type
|
||||
switch s.modelType {
|
||||
case "GlmImagePipeline":
|
||||
req.Steps = 50 // GLM-Image default
|
||||
default:
|
||||
req.Steps = 9 // Z-Image turbo default
|
||||
}
|
||||
req.Steps = 9
|
||||
}
|
||||
if req.Seed <= 0 {
|
||||
req.Seed = time.Now().UnixNano()
|
||||
@@ -192,9 +159,25 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate image using interface method
|
||||
// Generate image
|
||||
ctx := r.Context()
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
|
||||
img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: req.Seed,
|
||||
Progress: func(step, total int) {
|
||||
resp := Response{
|
||||
Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
|
||||
Done: false,
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
w.Write(data)
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
@@ -233,35 +216,3 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// detectModelType reads the model manifest and returns the pipeline class name
|
||||
func detectModelType(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := manifest.ReadConfig("model_index.json")
|
||||
if err != nil {
|
||||
return "ZImagePipeline", nil // Default to Z-Image
|
||||
}
|
||||
|
||||
// Try both _class_name (diffusers format) and architecture (ollama format)
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err != nil {
|
||||
return "ZImagePipeline", nil
|
||||
}
|
||||
|
||||
// Prefer _class_name, fall back to architecture
|
||||
className := index.ClassName
|
||||
if className == "" {
|
||||
className = index.Architecture
|
||||
}
|
||||
if className == "" {
|
||||
return "ZImagePipeline", nil
|
||||
}
|
||||
return className, nil
|
||||
}
|
||||
|
||||
271
x/server/show.go
Normal file
271
x/server/show.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
// modelConfig represents the HuggingFace config.json structure
|
||||
type modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
IntermediateSize int `json:"intermediate_size"`
|
||||
NumAttentionHeads int `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int `json:"num_key_value_heads"`
|
||||
VocabSize int `json:"vocab_size"`
|
||||
RMSNormEps float64 `json:"rms_norm_eps"`
|
||||
RopeTheta float64 `json:"rope_theta"`
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
TextConfig *struct {
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
|
||||
// It reads the config.json layer and returns a map compatible with GGML's KV format.
|
||||
func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
var config modelConfig
|
||||
if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
// Calculate total tensor bytes from manifest layers
|
||||
var totalBytes int64
|
||||
var tensorCount int64
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.tensor" {
|
||||
totalBytes += layer.Size
|
||||
tensorCount++
|
||||
}
|
||||
}
|
||||
|
||||
return buildModelInfo(config, totalBytes, tensorCount), nil
|
||||
}
|
||||
|
||||
// buildModelInfo constructs the model info map from config and tensor stats.
|
||||
// This is separated for testability.
|
||||
func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map[string]any {
|
||||
// Determine architecture
|
||||
arch := config.ModelType
|
||||
if arch == "" && len(config.Architectures) > 0 {
|
||||
// Convert HuggingFace architecture name to Ollama format
|
||||
// e.g., "Gemma3ForCausalLM" -> "gemma3"
|
||||
hfArch := config.Architectures[0]
|
||||
arch = strings.ToLower(hfArch)
|
||||
arch = strings.TrimSuffix(arch, "forcausallm")
|
||||
arch = strings.TrimSuffix(arch, "forconditionalgeneration")
|
||||
}
|
||||
|
||||
// Use text_config values if they exist (for multimodal models)
|
||||
hiddenSize := config.HiddenSize
|
||||
maxPosEmbed := config.MaxPositionEmbeddings
|
||||
numLayers := config.NumHiddenLayers
|
||||
|
||||
if config.TextConfig != nil {
|
||||
if config.TextConfig.HiddenSize > 0 {
|
||||
hiddenSize = config.TextConfig.HiddenSize
|
||||
}
|
||||
if config.TextConfig.MaxPositionEmbeddings > 0 {
|
||||
maxPosEmbed = config.TextConfig.MaxPositionEmbeddings
|
||||
}
|
||||
if config.TextConfig.NumHiddenLayers > 0 {
|
||||
numLayers = config.TextConfig.NumHiddenLayers
|
||||
}
|
||||
}
|
||||
|
||||
// Get dtype to determine bytes per parameter for count calculation
|
||||
dtype := config.TorchDtype
|
||||
|
||||
// Determine bytes per parameter based on dtype
|
||||
var bytesPerParam int64 = 2 // default to float16/bfloat16
|
||||
switch strings.ToLower(dtype) {
|
||||
case "float32":
|
||||
bytesPerParam = 4
|
||||
case "float16", "bfloat16":
|
||||
bytesPerParam = 2
|
||||
case "int8", "uint8":
|
||||
bytesPerParam = 1
|
||||
}
|
||||
|
||||
// Subtract safetensors header overhead (88 bytes per tensor file)
|
||||
// Each tensor is stored as a minimal safetensors file
|
||||
totalBytes := totalTensorBytes - tensorCount*88
|
||||
|
||||
paramCount := totalBytes / bytesPerParam
|
||||
|
||||
info := map[string]any{
|
||||
"general.architecture": arch,
|
||||
}
|
||||
|
||||
if maxPosEmbed > 0 {
|
||||
info[fmt.Sprintf("%s.context_length", arch)] = maxPosEmbed
|
||||
}
|
||||
|
||||
if hiddenSize > 0 {
|
||||
info[fmt.Sprintf("%s.embedding_length", arch)] = hiddenSize
|
||||
}
|
||||
|
||||
if numLayers > 0 {
|
||||
info[fmt.Sprintf("%s.block_count", arch)] = numLayers
|
||||
}
|
||||
|
||||
if config.NumAttentionHeads > 0 {
|
||||
info[fmt.Sprintf("%s.attention.head_count", arch)] = config.NumAttentionHeads
|
||||
}
|
||||
|
||||
if config.NumKeyValueHeads > 0 {
|
||||
info[fmt.Sprintf("%s.attention.head_count_kv", arch)] = config.NumKeyValueHeads
|
||||
}
|
||||
|
||||
if config.IntermediateSize > 0 {
|
||||
info[fmt.Sprintf("%s.feed_forward_length", arch)] = config.IntermediateSize
|
||||
}
|
||||
|
||||
if config.VocabSize > 0 {
|
||||
info[fmt.Sprintf("%s.vocab_size", arch)] = config.VocabSize
|
||||
}
|
||||
|
||||
if paramCount > 0 {
|
||||
info["general.parameter_count"] = paramCount
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||
func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
return getTensorInfoFromManifest(manifest)
|
||||
}
|
||||
|
||||
// getTensorInfoFromManifest extracts tensor info from a manifest.
|
||||
// This is separated for testability.
|
||||
func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
|
||||
var tensors []api.Tensor
|
||||
|
||||
for _, layer := range manifest.Manifest.Layers {
|
||||
if layer.MediaType != "application/vnd.ollama.image.tensor" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read the safetensors header from the blob
|
||||
blobPath := manifest.BlobPath(layer.Digest)
|
||||
info, err := readSafetensorsHeader(blobPath)
|
||||
if err != nil {
|
||||
// Skip tensors we can't read
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert shape from int to uint64
|
||||
shape := make([]uint64, len(info.Shape))
|
||||
for i, s := range info.Shape {
|
||||
shape[i] = uint64(s)
|
||||
}
|
||||
|
||||
tensors = append(tensors, api.Tensor{
|
||||
Name: layer.Name,
|
||||
Type: info.Dtype,
|
||||
Shape: shape,
|
||||
})
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
}
|
||||
|
||||
// GetSafetensorsDtype returns the torch_dtype from config.json for a safetensors model.
|
||||
func GetSafetensorsDtype(modelName string) (string, error) {
|
||||
manifest, err := imagegen.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load manifest: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
|
||||
return "", fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.TorchDtype, nil
|
||||
}
|
||||
|
||||
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
|
||||
type safetensorsTensorInfo struct {
|
||||
Dtype string `json:"dtype"`
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
|
||||
// Safetensors format: 8-byte header size (little endian) + JSON header + tensor data
|
||||
func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return parseSafetensorsHeader(f)
|
||||
}
|
||||
|
||||
// parseSafetensorsHeader parses a safetensors header from a reader.
|
||||
// This is separated for testability.
|
||||
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
|
||||
// Read header size (8 bytes, little endian)
|
||||
var headerSize uint64
|
||||
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header size: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check - header shouldn't be too large
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header size too large: %d", headerSize)
|
||||
}
|
||||
|
||||
// Read header JSON
|
||||
headerBytes := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(r, headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
// Parse as map of tensor name -> info
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse header: %w", err)
|
||||
}
|
||||
|
||||
// Find the first (and should be only) tensor entry
|
||||
for name, raw := range header {
|
||||
if name == "__metadata__" {
|
||||
continue
|
||||
}
|
||||
var info safetensorsTensorInfo
|
||||
if err := json.Unmarshal(raw, &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no tensor found in header")
|
||||
}
|
||||
605
x/server/show_test.go
Normal file
605
x/server/show_test.go
Normal file
@@ -0,0 +1,605 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
)
|
||||
|
||||
func TestBuildModelInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config modelConfig
|
||||
totalTensorBytes int64
|
||||
tensorCount int64
|
||||
wantArch string
|
||||
wantContextLen int
|
||||
wantEmbedLen int
|
||||
wantBlockCount int
|
||||
wantParamCount int64
|
||||
}{
|
||||
{
|
||||
name: "gemma3 model with model_type",
|
||||
config: modelConfig{
|
||||
ModelType: "gemma3",
|
||||
HiddenSize: 2560,
|
||||
NumHiddenLayers: 34,
|
||||
MaxPositionEmbeddings: 131072,
|
||||
IntermediateSize: 10240,
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
VocabSize: 262144,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "gemma3",
|
||||
wantContextLen: 131072,
|
||||
wantEmbedLen: 2560,
|
||||
wantBlockCount: 34,
|
||||
wantParamCount: 4_300_000_000,
|
||||
},
|
||||
{
|
||||
name: "llama model with architectures array",
|
||||
config: modelConfig{
|
||||
Architectures: []string{"LlamaForCausalLM"},
|
||||
HiddenSize: 4096,
|
||||
NumHiddenLayers: 32,
|
||||
MaxPositionEmbeddings: 4096,
|
||||
IntermediateSize: 11008,
|
||||
NumAttentionHeads: 32,
|
||||
NumKeyValueHeads: 32,
|
||||
VocabSize: 32000,
|
||||
TorchDtype: "float16",
|
||||
},
|
||||
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "llama",
|
||||
wantContextLen: 4096,
|
||||
wantEmbedLen: 4096,
|
||||
wantBlockCount: 32,
|
||||
wantParamCount: 7_000_000_000,
|
||||
},
|
||||
{
|
||||
name: "multimodal model with text_config",
|
||||
config: modelConfig{
|
||||
Architectures: []string{"Gemma3ForConditionalGeneration"},
|
||||
HiddenSize: 1152, // vision hidden size
|
||||
TextConfig: &struct {
|
||||
HiddenSize int `json:"hidden_size"`
|
||||
MaxPositionEmbeddings int `json:"max_position_embeddings"`
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
}{
|
||||
HiddenSize: 2560,
|
||||
MaxPositionEmbeddings: 131072,
|
||||
NumHiddenLayers: 34,
|
||||
},
|
||||
NumAttentionHeads: 8,
|
||||
NumKeyValueHeads: 4,
|
||||
VocabSize: 262144,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 8_600_000_088,
|
||||
tensorCount: 1,
|
||||
wantArch: "gemma3",
|
||||
wantContextLen: 131072,
|
||||
wantEmbedLen: 2560,
|
||||
wantBlockCount: 34,
|
||||
wantParamCount: 4_300_000_000,
|
||||
},
|
||||
{
|
||||
name: "float32 model",
|
||||
config: modelConfig{
|
||||
ModelType: "test",
|
||||
HiddenSize: 512,
|
||||
NumHiddenLayers: 6,
|
||||
MaxPositionEmbeddings: 2048,
|
||||
TorchDtype: "float32",
|
||||
},
|
||||
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
|
||||
tensorCount: 1,
|
||||
wantArch: "test",
|
||||
wantContextLen: 2048,
|
||||
wantEmbedLen: 512,
|
||||
wantBlockCount: 6,
|
||||
wantParamCount: 100_000_000,
|
||||
},
|
||||
{
|
||||
name: "multiple tensors with header overhead",
|
||||
config: modelConfig{
|
||||
ModelType: "test",
|
||||
HiddenSize: 256,
|
||||
NumHiddenLayers: 4,
|
||||
MaxPositionEmbeddings: 1024,
|
||||
TorchDtype: "bfloat16",
|
||||
},
|
||||
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
|
||||
tensorCount: 10,
|
||||
wantArch: "test",
|
||||
wantContextLen: 1024,
|
||||
wantEmbedLen: 256,
|
||||
wantBlockCount: 4,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := buildModelInfo(tt.config, tt.totalTensorBytes, tt.tensorCount)
|
||||
|
||||
// Check architecture
|
||||
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
|
||||
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
|
||||
}
|
||||
|
||||
// Check context length
|
||||
contextKey := tt.wantArch + ".context_length"
|
||||
if contextLen, ok := info[contextKey].(int); !ok || contextLen != tt.wantContextLen {
|
||||
t.Errorf("context_length = %v, want %v", info[contextKey], tt.wantContextLen)
|
||||
}
|
||||
|
||||
// Check embedding length
|
||||
embedKey := tt.wantArch + ".embedding_length"
|
||||
if embedLen, ok := info[embedKey].(int); !ok || embedLen != tt.wantEmbedLen {
|
||||
t.Errorf("embedding_length = %v, want %v", info[embedKey], tt.wantEmbedLen)
|
||||
}
|
||||
|
||||
// Check block count
|
||||
blockKey := tt.wantArch + ".block_count"
|
||||
if blockCount, ok := info[blockKey].(int); !ok || blockCount != tt.wantBlockCount {
|
||||
t.Errorf("block_count = %v, want %v", info[blockKey], tt.wantBlockCount)
|
||||
}
|
||||
|
||||
// Check parameter count
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
|
||||
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelInfo_ArchitectureConversion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
architectures []string
|
||||
modelType string
|
||||
wantArch string
|
||||
}{
|
||||
{
|
||||
name: "LlamaForCausalLM",
|
||||
architectures: []string{"LlamaForCausalLM"},
|
||||
wantArch: "llama",
|
||||
},
|
||||
{
|
||||
name: "Gemma3ForCausalLM",
|
||||
architectures: []string{"Gemma3ForCausalLM"},
|
||||
wantArch: "gemma3",
|
||||
},
|
||||
{
|
||||
name: "Gemma3ForConditionalGeneration",
|
||||
architectures: []string{"Gemma3ForConditionalGeneration"},
|
||||
wantArch: "gemma3",
|
||||
},
|
||||
{
|
||||
name: "Qwen2ForCausalLM",
|
||||
architectures: []string{"Qwen2ForCausalLM"},
|
||||
wantArch: "qwen2",
|
||||
},
|
||||
{
|
||||
name: "model_type takes precedence",
|
||||
architectures: []string{"LlamaForCausalLM"},
|
||||
modelType: "custom",
|
||||
wantArch: "custom",
|
||||
},
|
||||
{
|
||||
name: "empty architectures with model_type",
|
||||
architectures: nil,
|
||||
modelType: "mymodel",
|
||||
wantArch: "mymodel",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := modelConfig{
|
||||
Architectures: tt.architectures,
|
||||
ModelType: tt.modelType,
|
||||
}
|
||||
info := buildModelInfo(config, 0, 0)
|
||||
|
||||
if arch, ok := info["general.architecture"].(string); !ok || arch != tt.wantArch {
|
||||
t.Errorf("architecture = %v, want %v", info["general.architecture"], tt.wantArch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelInfo_BytesPerParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dtype string
|
||||
totalBytes int64
|
||||
tensorCount int64
|
||||
wantParamCount int64
|
||||
}{
|
||||
{
|
||||
name: "bfloat16",
|
||||
dtype: "bfloat16",
|
||||
totalBytes: 2_000_088, // 1M * 2 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "float16",
|
||||
dtype: "float16",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "float32",
|
||||
dtype: "float32",
|
||||
totalBytes: 4_000_088, // 1M * 4 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "int8",
|
||||
dtype: "int8",
|
||||
totalBytes: 1_000_088, // 1M * 1 + 88
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "unknown dtype defaults to 2 bytes",
|
||||
dtype: "unknown",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
{
|
||||
name: "empty dtype defaults to 2 bytes",
|
||||
dtype: "",
|
||||
totalBytes: 2_000_088,
|
||||
tensorCount: 1,
|
||||
wantParamCount: 1_000_000,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := modelConfig{
|
||||
ModelType: "test",
|
||||
TorchDtype: tt.dtype,
|
||||
}
|
||||
info := buildModelInfo(config, tt.totalBytes, tt.tensorCount)
|
||||
|
||||
if paramCount, ok := info["general.parameter_count"].(int64); !ok || paramCount != tt.wantParamCount {
|
||||
t.Errorf("parameter_count = %v, want %v", info["general.parameter_count"], tt.wantParamCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header map[string]any
|
||||
wantDtype string
|
||||
wantShape []int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple tensor",
|
||||
header: map[string]any{
|
||||
"weight": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{2560, 262144},
|
||||
"data_offsets": []int64{0, 1342177280},
|
||||
},
|
||||
},
|
||||
wantDtype: "BF16",
|
||||
wantShape: []int64{2560, 262144},
|
||||
},
|
||||
{
|
||||
name: "with metadata",
|
||||
header: map[string]any{
|
||||
"__metadata__": map[string]any{
|
||||
"format": "pt",
|
||||
},
|
||||
"bias": map[string]any{
|
||||
"dtype": "F32",
|
||||
"shape": []int64{1024},
|
||||
"data_offsets": []int64{0, 4096},
|
||||
},
|
||||
},
|
||||
wantDtype: "F32",
|
||||
wantShape: []int64{1024},
|
||||
},
|
||||
{
|
||||
name: "float16 tensor",
|
||||
header: map[string]any{
|
||||
"layer.weight": map[string]any{
|
||||
"dtype": "F16",
|
||||
"shape": []int64{512, 512, 3, 3},
|
||||
"data_offsets": []int64{0, 4718592},
|
||||
},
|
||||
},
|
||||
wantDtype: "F16",
|
||||
wantShape: []int64{512, 512, 3, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create safetensors format: 8-byte size + JSON header
|
||||
headerJSON, err := json.Marshal(tt.header)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
|
||||
t.Fatalf("failed to write header size: %v", err)
|
||||
}
|
||||
buf.Write(headerJSON)
|
||||
|
||||
info, err := parseSafetensorsHeader(&buf)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseSafetensorsHeader() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if info.Dtype != tt.wantDtype {
|
||||
t.Errorf("Dtype = %v, want %v", info.Dtype, tt.wantDtype)
|
||||
}
|
||||
|
||||
if len(info.Shape) != len(tt.wantShape) {
|
||||
t.Errorf("Shape length = %v, want %v", len(info.Shape), len(tt.wantShape))
|
||||
} else {
|
||||
for i, s := range info.Shape {
|
||||
if s != tt.wantShape[i] {
|
||||
t.Errorf("Shape[%d] = %v, want %v", i, s, tt.wantShape[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSafetensorsHeader_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
wantErr: "failed to read header size",
|
||||
},
|
||||
{
|
||||
name: "truncated header size",
|
||||
data: []byte{0x01, 0x02, 0x03},
|
||||
wantErr: "failed to read header size",
|
||||
},
|
||||
{
|
||||
name: "header size too large",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(2*1024*1024)) // 2MB
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "header size too large",
|
||||
},
|
||||
{
|
||||
name: "truncated header",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(100))
|
||||
buf.Write([]byte("short"))
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "failed to read header",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
data: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(10))
|
||||
buf.Write([]byte("not json!!"))
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "failed to parse header",
|
||||
},
|
||||
{
|
||||
name: "no tensors in header",
|
||||
data: func() []byte {
|
||||
header := map[string]any{
|
||||
"__metadata__": map[string]any{"format": "pt"},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
wantErr: "no tensor found in header",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := parseSafetensorsHeader(bytes.NewReader(tt.data))
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !bytes.Contains([]byte(err.Error()), []byte(tt.wantErr)) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorInfoFromManifest(t *testing.T) {
|
||||
// Create a temp directory for blobs
|
||||
tempDir, err := os.MkdirTemp("", "ollama-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create test tensor blobs
|
||||
tensors := []struct {
|
||||
name string
|
||||
digest string
|
||||
dtype string
|
||||
shape []int64
|
||||
}{
|
||||
{
|
||||
name: "model.embed_tokens.weight",
|
||||
digest: "sha256:abc123",
|
||||
dtype: "BF16",
|
||||
shape: []int64{262144, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.layers.0.self_attn.q_proj.weight",
|
||||
digest: "sha256:def456",
|
||||
dtype: "BF16",
|
||||
shape: []int64{2560, 2560},
|
||||
},
|
||||
{
|
||||
name: "model.norm.weight",
|
||||
digest: "sha256:ghi789",
|
||||
dtype: "F32",
|
||||
shape: []int64{2560},
|
||||
},
|
||||
}
|
||||
|
||||
// Create blob files
|
||||
var layers []imagegen.ManifestLayer
|
||||
for _, tensor := range tensors {
|
||||
// Create safetensors blob
|
||||
header := map[string]any{
|
||||
tensor.name: map[string]any{
|
||||
"dtype": tensor.dtype,
|
||||
"shape": tensor.shape,
|
||||
"data_offsets": []int64{0, 1000},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
// Write blob file
|
||||
blobName := "sha256-" + tensor.digest[7:]
|
||||
blobPath := filepath.Join(tempDir, blobName)
|
||||
if err := os.WriteFile(blobPath, buf.Bytes(), 0644); err != nil {
|
||||
t.Fatalf("failed to write blob: %v", err)
|
||||
}
|
||||
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.tensor",
|
||||
Digest: tensor.digest,
|
||||
Size: int64(buf.Len() + 1000), // header + fake data
|
||||
Name: tensor.name,
|
||||
})
|
||||
}
|
||||
|
||||
// Add a non-tensor layer (should be skipped)
|
||||
layers = append(layers, imagegen.ManifestLayer{
|
||||
MediaType: "application/vnd.ollama.image.json",
|
||||
Digest: "sha256:config",
|
||||
Size: 100,
|
||||
Name: "config.json",
|
||||
})
|
||||
|
||||
manifest := &imagegen.ModelManifest{
|
||||
Manifest: &imagegen.Manifest{
|
||||
Layers: layers,
|
||||
},
|
||||
BlobDir: tempDir,
|
||||
}
|
||||
|
||||
result, err := getTensorInfoFromManifest(manifest)
|
||||
if err != nil {
|
||||
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
if len(result) != 3 {
|
||||
t.Errorf("got %d tensors, want 3", len(result))
|
||||
}
|
||||
|
||||
// Verify each tensor
|
||||
for i, tensor := range tensors {
|
||||
if i >= len(result) {
|
||||
break
|
||||
}
|
||||
if result[i].Name != tensor.name {
|
||||
t.Errorf("tensor[%d].Name = %v, want %v", i, result[i].Name, tensor.name)
|
||||
}
|
||||
if result[i].Type != tensor.dtype {
|
||||
t.Errorf("tensor[%d].Type = %v, want %v", i, result[i].Type, tensor.dtype)
|
||||
}
|
||||
if len(result[i].Shape) != len(tensor.shape) {
|
||||
t.Errorf("tensor[%d].Shape length = %v, want %v", i, len(result[i].Shape), len(tensor.shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSafetensorsHeader(t *testing.T) {
|
||||
// Create a temp file with a valid safetensors header
|
||||
tempDir, err := os.MkdirTemp("", "ollama-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
header := map[string]any{
|
||||
"test_tensor": map[string]any{
|
||||
"dtype": "BF16",
|
||||
"shape": []int64{1024, 768},
|
||||
"data_offsets": []int64{0, 1572864},
|
||||
},
|
||||
}
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||
buf.Write(headerJSON)
|
||||
|
||||
filePath := filepath.Join(tempDir, "test.safetensors")
|
||||
if err := os.WriteFile(filePath, buf.Bytes(), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
info, err := readSafetensorsHeader(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("readSafetensorsHeader() error = %v", err)
|
||||
}
|
||||
|
||||
if info.Dtype != "BF16" {
|
||||
t.Errorf("Dtype = %v, want BF16", info.Dtype)
|
||||
}
|
||||
if len(info.Shape) != 2 || info.Shape[0] != 1024 || info.Shape[1] != 768 {
|
||||
t.Errorf("Shape = %v, want [1024, 768]", info.Shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSafetensorsHeader_FileNotFound(t *testing.T) {
|
||||
_, err := readSafetensorsHeader("/nonexistent/path/file.safetensors")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user