Mirror of https://github.com/ollama/ollama.git (synced 2026-01-19 04:51:17 -05:00)

Compare commits: parth/decr ... jmorganca/ (2 commits)

Commits: 0d3648c1be, 02a2401596
@@ -190,7 +190,7 @@ if(MLX_ENGINE)
     install(TARGETS mlx mlxc
         RUNTIME_DEPENDENCIES
             DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
-            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc cudnn nccl
+            PRE_INCLUDE_REGEXES cublas cublasLt cudart nvrtc nvrtc-builtins cudnn nccl openblas gfortran
            PRE_EXCLUDE_REGEXES ".*"
        RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
        LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT MLX
@@ -95,48 +95,11 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 }
 
 const (
-	// numDownloadParts is the default number of concurrent download parts for standard downloads
 	numDownloadParts          = 16
-	// numHFDownloadParts is the reduced number of concurrent download parts for HuggingFace
-	// downloads to avoid triggering rate limits (HTTP 429 errors). See GitHub issue #13297.
-	numHFDownloadParts        = 4
 	minDownloadPartSize int64 = 100 * format.MegaByte
 	maxDownloadPartSize int64 = 1000 * format.MegaByte
 )
 
-// isHuggingFaceURL returns true if the URL is from a HuggingFace domain.
-// This includes:
-//   - huggingface.co (main domain)
-//   - *.huggingface.co (subdomains like cdn-lfs.huggingface.co)
-//   - hf.co (shortlink domain)
-//   - *.hf.co (CDN domains like cdn-lfs.hf.co, cdn-lfs3.hf.co)
-func isHuggingFaceURL(u *url.URL) bool {
-	if u == nil {
-		return false
-	}
-	host := strings.ToLower(u.Hostname())
-	return host == "huggingface.co" ||
-		strings.HasSuffix(host, ".huggingface.co") ||
-		host == "hf.co" ||
-		strings.HasSuffix(host, ".hf.co")
-}
-
-// getNumDownloadParts returns the number of concurrent download parts to use
-// for the given URL. HuggingFace URLs use reduced concurrency (default 4) to
-// avoid triggering rate limits. This can be overridden via the OLLAMA_HF_CONCURRENCY
-// environment variable. For non-HuggingFace URLs, returns the standard concurrency (16).
-func getNumDownloadParts(u *url.URL) int {
-	if isHuggingFaceURL(u) {
-		if v := os.Getenv("OLLAMA_HF_CONCURRENCY"); v != "" {
-			if n, err := strconv.Atoi(v); err == nil && n > 0 {
-				return n
-			}
-		}
-		return numHFDownloadParts
-	}
-	return numDownloadParts
-}
-
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -308,11 +271,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	}
 
 	g, inner := errgroup.WithContext(ctx)
-	concurrency := getNumDownloadParts(directURL)
-	if concurrency != numDownloadParts {
-		slog.Info(fmt.Sprintf("using reduced concurrency (%d) for HuggingFace download", concurrency))
-	}
-	g.SetLimit(concurrency)
+	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
(deleted file, -194)
@@ -1,194 +0,0 @@
package server

import (
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestIsHuggingFaceURL(t *testing.T) {
	tests := []struct {
		name     string
		url      string
		expected bool
	}{
		{
			name:     "nil url",
			url:      "",
			expected: false,
		},
		{
			name:     "huggingface.co main domain",
			url:      "https://huggingface.co/some/model",
			expected: true,
		},
		{
			name:     "cdn-lfs.huggingface.co subdomain",
			url:      "https://cdn-lfs.huggingface.co/repos/abc/123",
			expected: true,
		},
		{
			name:     "cdn-lfs3.hf.co CDN domain",
			url:      "https://cdn-lfs3.hf.co/repos/abc/123",
			expected: true,
		},
		{
			name:     "hf.co shortlink domain",
			url:      "https://hf.co/model",
			expected: true,
		},
		{
			name:     "uppercase HuggingFace domain",
			url:      "https://HUGGINGFACE.CO/model",
			expected: true,
		},
		{
			name:     "mixed case HF domain",
			url:      "https://Cdn-Lfs.HF.Co/repos",
			expected: true,
		},
		{
			name:     "ollama registry",
			url:      "https://registry.ollama.ai/v2/library/llama3",
			expected: false,
		},
		{
			name:     "github.com",
			url:      "https://github.com/ollama/ollama",
			expected: false,
		},
		{
			name:     "fake huggingface domain",
			url:      "https://nothuggingface.co/model",
			expected: false,
		},
		{
			name:     "fake hf domain",
			url:      "https://nothf.co/model",
			expected: false,
		},
		{
			name:     "huggingface in path not host",
			url:      "https://example.com/huggingface.co/model",
			expected: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var u *url.URL
			if tc.url != "" {
				var err error
				u, err = url.Parse(tc.url)
				if err != nil {
					t.Fatalf("failed to parse URL: %v", err)
				}
			}
			got := isHuggingFaceURL(u)
			assert.Equal(t, tc.expected, got)
		})
	}
}

func TestGetNumDownloadParts(t *testing.T) {
	tests := []struct {
		name        string
		url         string
		envValue    string
		expected    int
		description string
	}{
		{
			name:        "nil url returns default",
			url:         "",
			envValue:    "",
			expected:    numDownloadParts,
			description: "nil URL should return standard concurrency",
		},
		{
			name:        "ollama registry returns default",
			url:         "https://registry.ollama.ai/v2/library/llama3",
			envValue:    "",
			expected:    numDownloadParts,
			description: "Ollama registry should use standard concurrency",
		},
		{
			name:        "huggingface returns reduced default",
			url:         "https://huggingface.co/model/repo",
			envValue:    "",
			expected:    numHFDownloadParts,
			description: "HuggingFace should use reduced concurrency",
		},
		{
			name:        "hf.co CDN returns reduced default",
			url:         "https://cdn-lfs3.hf.co/repos/abc/123",
			envValue:    "",
			expected:    numHFDownloadParts,
			description: "HuggingFace CDN should use reduced concurrency",
		},
		{
			name:        "huggingface with env override",
			url:         "https://huggingface.co/model/repo",
			envValue:    "2",
			expected:    2,
			description: "OLLAMA_HF_CONCURRENCY should override default",
		},
		{
			name:        "huggingface with higher env override",
			url:         "https://huggingface.co/model/repo",
			envValue:    "8",
			expected:    8,
			description: "OLLAMA_HF_CONCURRENCY can be set higher than default",
		},
		{
			name:        "huggingface with invalid env (non-numeric)",
			url:         "https://huggingface.co/model/repo",
			envValue:    "invalid",
			expected:    numHFDownloadParts,
			description: "Invalid OLLAMA_HF_CONCURRENCY should fall back to default",
		},
		{
			name:        "huggingface with invalid env (zero)",
			url:         "https://huggingface.co/model/repo",
			envValue:    "0",
			expected:    numHFDownloadParts,
			description: "Zero OLLAMA_HF_CONCURRENCY should fall back to default",
		},
		{
			name:        "huggingface with invalid env (negative)",
			url:         "https://huggingface.co/model/repo",
			envValue:    "-1",
			expected:    numHFDownloadParts,
			description: "Negative OLLAMA_HF_CONCURRENCY should fall back to default",
		},
		{
			name:        "non-huggingface ignores env",
			url:         "https://registry.ollama.ai/v2/library/llama3",
			envValue:    "2",
			expected:    numDownloadParts,
			description: "OLLAMA_HF_CONCURRENCY should not affect non-HF URLs",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Set or clear the environment variable
			if tc.envValue != "" {
				t.Setenv("OLLAMA_HF_CONCURRENCY", tc.envValue)
			}

			var u *url.URL
			if tc.url != "" {
				var err error
				u, err = url.Parse(tc.url)
				if err != nil {
					t.Fatalf("failed to parse URL: %v", err)
				}
			}

			got := getNumDownloadParts(u)
			assert.Equal(t, tc.expected, got, tc.description)
		})
	}
}
@@ -11,9 +11,11 @@ import (
 	"os"
 	"path/filepath"
 	"runtime/pprof"
+	"strings"
 
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/models/gemma3"
+	"github.com/ollama/ollama/x/imagegen/models/glm_image"
 	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
 	"github.com/ollama/ollama/x/imagegen/models/llama"
 	"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -61,6 +63,7 @@ func main() {
 
 	// Legacy mode flags
 	zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
+	glmImageFlag := flag.Bool("glm-image", false, "GLM-Image generation")
 	qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
 	qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
 	var inputImages stringSlice
@@ -117,6 +120,33 @@ func main() {
 		if err == nil {
 			err = saveImageArray(img, *out)
 		}
+	case *glmImageFlag:
+		m := &glm_image.Model{}
+		// Use LoadFromPath if model path looks like a directory, otherwise use Load (ollama manifest)
+		var loadErr error
+		if strings.HasPrefix(*modelPath, ".") || strings.HasPrefix(*modelPath, "/") {
+			loadErr = m.LoadFromPath(*modelPath)
+		} else {
+			loadErr = m.Load(*modelPath)
+		}
+		if loadErr != nil {
+			log.Fatal(loadErr)
+		}
+		var img *mlx.Array
+		img, err = m.GenerateFromConfig(context.Background(), &glm_image.GenerateConfig{
+			Prompt:          *prompt,
+			Width:           int32(*width),
+			Height:          int32(*height),
+			Steps:           *steps,
+			Seed:            *seed,
+			Temperature:     float32(*temperature),
+			TopP:            float32(*topP),
+			GuidanceScale:   float32(*cfgScale),
+			MaxVisualTokens: int32(*maxTokens),
+		})
+		if err == nil {
+			err = saveImageArray(img, *out)
+		}
 	case *qwenImage:
 		m, loadErr := qwen_image.LoadPersistent(*modelPath)
 		if loadErr != nil {
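A note on the dispatch above (editor's addition): a model argument beginning with "." or "/" (for example ./GLM-Image) is treated as a local directory and loaded via LoadFromPath, while anything else is resolved as an Ollama manifest name via Load. The flag that populates *modelPath is presumably defined earlier in main() and is not shown in this hunk.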
@@ -48,7 +48,7 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
 	var totalParams int64 // Count parameters from original tensor shapes
 
 	// Components to process - extract individual tensors from each
-	components := []string{"text_encoder", "transformer", "vae"}
+	components := []string{"text_encoder", "transformer", "vae", "vision_language_encoder"}
 
 	for _, component := range components {
 		componentDir := filepath.Join(modelDir, component)
@@ -126,10 +126,13 @@ func CreateModel(modelName, modelDir, quantize string, createLayer LayerCreator,
 		"text_encoder/generation_config.json",
 		"transformer/config.json",
 		"vae/config.json",
+		"vision_language_encoder/config.json",
 		"scheduler/scheduler_config.json",
 		"tokenizer/tokenizer.json",
 		"tokenizer/tokenizer_config.json",
 		"tokenizer/vocab.json",
+		"processor/tokenizer.json",        // GLM-Image main tokenizer
+		"processor/tokenizer_config.json", // GLM-Image tokenizer config
 	}
 
 	for _, cfgPath := range configFiles {
x/imagegen/imagegen.md (new file, +19)
@@ -0,0 +1,19 @@
# Image generation models (experimental)

Experimental image generation models are available for **macOS** in Ollama:

## Available models

- [Z-Image-Turbo](https://ollama.com/x/z-image-turbo)

```
ollama run x/z-image-turbo
```

> **Note**: [`x`](https://ollama.com/x) is a username on ollama.com where the maintainer team uploads experimental models

More models coming soon:

1. Qwen-Image-2512
2. Qwen-Image-Edit-2511
3. GLM-Image
@@ -27,6 +27,7 @@ var modelVRAMEstimates = map[string]uint64{
 	"ZImagePipeline":    21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
 	"FluxPipeline":      21 * GB, // ~21GB for Flux (same architecture)
 	"QwenImagePipeline": 80 * GB, // TODO: verify actual requirements, using conservative estimate for now
+	"GlmImagePipeline":  80 * GB, // ~34GB weights + ~46GB working memory for 9B+7B hybrid model
 }
 
 // CheckPlatformSupport validates that image generation is supported on the current platform.
x/imagegen/models/glm_image/glm_image.go (new file, +693)
@@ -0,0 +1,693 @@
//go:build mlx

// Package glm_image implements the GLM-Image hybrid AR + diffusion model.
package glm_image

import (
	"context"
	"fmt"
	"math"
	"path/filepath"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// ByT5Tokenizer is a simple byte-level tokenizer for ByT5
// ByT5 uses bytes as tokens: each byte (0-255) maps to token ID (3-258)
// Special tokens: 0=pad, 1=eos, 2=unk
type ByT5Tokenizer struct {
	PadTokenID int32
	EOSTokenID int32
	UNKTokenID int32
}

// NewByT5Tokenizer creates a new ByT5 tokenizer
func NewByT5Tokenizer() *ByT5Tokenizer {
	return &ByT5Tokenizer{
		PadTokenID: 0,
		EOSTokenID: 1,
		UNKTokenID: 2,
	}
}

// Encode converts a string to token IDs
func (t *ByT5Tokenizer) Encode(text string) []int32 {
	bytes := []byte(text)
	tokens := make([]int32, len(bytes))
	for i, b := range bytes {
		// Standard ByT5 tokenization: bytes 0-255 map to tokens 3-258
		// (tokens 0, 1, 2 are PAD, EOS, UNK)
		tokens[i] = int32(b) + 3
	}
	return tokens
}

// Decode converts token IDs back to a string
func (t *ByT5Tokenizer) Decode(tokens []int32) string {
	bytes := make([]byte, 0, len(tokens))
	for _, tok := range tokens {
		if tok >= 3 && tok < 259 {
			bytes = append(bytes, byte(tok-3))
		}
	}
	return string(bytes)
}
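// Editor's note (illustrative, not in the original source): a quick round trip
// under the +3 byte offset above: Encode("Hi") -> [75, 108] ('H'=72, 'i'=105),
// and Decode([75, 108]) -> "Hi".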

// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
	Prompt         string
	NegativePrompt string       // For CFG (optional, not typically used with GLM-Image)
	GuidanceScale  float32      // Guidance scale (default: 1.5)
	Width          int32        // Image width (default: 1024, must be divisible by 32)
	Height         int32        // Image height (default: 1024, must be divisible by 32)
	Steps          int          // Diffusion denoising steps (default: 50)
	Seed           int64        // Random seed
	Progress       ProgressFunc // Optional progress callback

	// AR generation options
	MaxVisualTokens int32   // Max visual tokens to generate (default: 256)
	Temperature     float32 // AR sampling temperature (default: 0.9)
	TopP            float32 // Nucleus sampling (default: 0.75)
}

// ProgressFunc is called during generation with stage and step progress.
type ProgressFunc func(stage string, step, totalSteps int)

// Model represents a GLM-Image hybrid model.
type Model struct {
	ModelName             string
	Tokenizer             *ByT5Tokenizer // For T5 text encoder (glyph embeddings)
	GLMTokenizer          *GLMTokenizer  // For AR model (visual token generation)
	TextEncoder           *T5TextEncoder
	VisionLanguageEncoder *VisionLanguageEncoder
	Transformer           *DiffusionTransformer
	VAEDecoder            *VAEDecoder
}

// Load loads the GLM-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
	fmt.Printf("Loading GLM-Image model from manifest: %s...\n", modelName)
	start := time.Now()

	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		mlx.EnableCompile()
	}

	m.ModelName = modelName

	// Load manifest
	manifest, err := imagegen.LoadManifest(modelName)
	if err != nil {
		return fmt.Errorf("load manifest: %w", err)
	}

	// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
	// Used for T5 text encoder (glyph embeddings)
	fmt.Print("  Creating ByT5 tokenizer... ")
	m.Tokenizer = NewByT5Tokenizer()
	fmt.Println("✓")

	// Load GLM tokenizer for AR model (visual token generation)
	fmt.Print("  Loading GLM tokenizer... ")
	glmTok, err := NewGLMTokenizer(manifest)
	if err != nil {
		return fmt.Errorf("glm tokenizer: %w", err)
	}
	m.GLMTokenizer = glmTok
	fmt.Println("✓")

	// Load T5 text encoder (~830MB)
	m.TextEncoder = &T5TextEncoder{}
	if err := m.TextEncoder.Load(manifest); err != nil {
		return fmt.Errorf("text encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.TextEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load vision-language encoder (~19GB, 9B params)
	m.VisionLanguageEncoder = &VisionLanguageEncoder{}
	if err := m.VisionLanguageEncoder.Load(manifest); err != nil {
		return fmt.Errorf("vision language encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load diffusion transformer (~13GB, 7B params)
	m.Transformer = &DiffusionTransformer{}
	if err := m.Transformer.Load(manifest); err != nil {
		return fmt.Errorf("transformer: %w", err)
	}
	mlx.Eval(mlx.Collect(m.Transformer)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load VAE decoder (~775MB)
	m.VAEDecoder = &VAEDecoder{}
	if err := m.VAEDecoder.Load(manifest); err != nil {
		return fmt.Errorf("VAE decoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VAEDecoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	mem := mlx.MetalGetActiveMemory()
	fmt.Printf("  Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))

	return nil
}

// LoadFromPath loads the model from a directory path (not ollama manifest)
func (m *Model) LoadFromPath(modelPath string) error {
	fmt.Printf("Loading GLM-Image model from path: %s...\n", modelPath)
	start := time.Now()

	if mlx.GPUIsAvailable() {
		mlx.SetDefaultDeviceGPU()
		mlx.EnableCompile()
	}

	m.ModelName = modelPath

	// Create ByT5 tokenizer (byte-level, no vocabulary file needed)
	fmt.Print("  Creating ByT5 tokenizer... ")
	m.Tokenizer = NewByT5Tokenizer()
	fmt.Println("✓")

	// Load GLM tokenizer for AR model (visual token generation)
	fmt.Print("  Loading GLM tokenizer... ")
	glmTok, err := NewGLMTokenizerFromPath(modelPath)
	if err != nil {
		return fmt.Errorf("glm tokenizer: %w", err)
	}
	m.GLMTokenizer = glmTok
	fmt.Println("✓")

	// Load T5 text encoder
	m.TextEncoder = &T5TextEncoder{}
	if err := m.TextEncoder.LoadFromPath(filepath.Join(modelPath, "text_encoder")); err != nil {
		return fmt.Errorf("text encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.TextEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load vision-language encoder
	m.VisionLanguageEncoder = &VisionLanguageEncoder{}
	if err := m.VisionLanguageEncoder.LoadFromPath(filepath.Join(modelPath, "vision_language_encoder")); err != nil {
		return fmt.Errorf("vision language encoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VisionLanguageEncoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load diffusion transformer
	m.Transformer = &DiffusionTransformer{}
	if err := m.Transformer.LoadFromPath(filepath.Join(modelPath, "transformer")); err != nil {
		return fmt.Errorf("transformer: %w", err)
	}
	mlx.Eval(mlx.Collect(m.Transformer)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	// Load VAE decoder
	m.VAEDecoder = &VAEDecoder{}
	if err := m.VAEDecoder.LoadFromPath(filepath.Join(modelPath, "vae")); err != nil {
		return fmt.Errorf("VAE decoder: %w", err)
	}
	mlx.Eval(mlx.Collect(m.VAEDecoder)...)
	fmt.Printf("  (%.1f GB, peak %.1f GB)\n",
		float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
		float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))

	mem := mlx.MetalGetActiveMemory()
	fmt.Printf("  Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))

	return nil
}

// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	return m.GenerateFromConfig(context.Background(), &GenerateConfig{
		Prompt: prompt,
		Width:  width,
		Height: height,
		Steps:  steps,
		Seed:   seed,
	})
}

// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
	return m.GenerateFromConfig(context.Background(), &GenerateConfig{
		Prompt:   prompt,
		Width:    width,
		Height:   height,
		Steps:    steps,
		Seed:     seed,
		Progress: progress,
	})
}

// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
	start := time.Now()
	result, err := m.generate(ctx, cfg)
	if err != nil {
		return nil, err
	}
	fmt.Printf("Generated in %.2fs (%d diffusion steps)\n", time.Since(start).Seconds(), cfg.Steps)
	return result, nil
}

// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
	return m.Generate(prompt, width, height, steps, seed)
}

// generate is the internal generation pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
	// Apply defaults
	if cfg.Width <= 0 {
		cfg.Width = 1024
	}
	if cfg.Height <= 0 {
		cfg.Height = 1024
	}
	if cfg.Steps <= 0 {
		cfg.Steps = 50
	}
	if cfg.GuidanceScale <= 0 {
		cfg.GuidanceScale = 1.5
	}
	// Calculate MaxVisualTokens based on image dimensions
	// GLM-Image generates TWO grids of visual tokens:
	//  1. First: prev (small) grid - prevTokenH × prevTokenW tokens
	//  2. Then: target (large) grid - tokenH × tokenW tokens
	// After generation, we extract only the TARGET grid tokens for diffusion.
	factor := int32(32)
	tokenH := cfg.Height / factor
	tokenW := cfg.Width / factor
	targetGridTokens := tokenH * tokenW

	// Compute prev grid dimensions using diffusers formula:
	//   ratio = token_h / token_w
	//   prev_token_h = int(sqrt(ratio) * 16)
	//   prev_token_w = int(sqrt(1/ratio) * 16)
	ratio := float64(tokenH) / float64(tokenW)
	prevTokenH := int32(math.Sqrt(ratio) * 16)
	prevTokenW := int32(math.Sqrt(1/ratio) * 16)
	prevGridTokens := prevTokenH * prevTokenW
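	// Editor's note (worked example, not in the original source): for the
	// default 1024x1024 image, tokenH = tokenW = 1024/32 = 32, so
	// targetGridTokens = 1024; ratio = 1, so prevTokenH = prevTokenW = 16 and
	// prevGridTokens = 256; MaxVisualTokens below becomes 256 + 1024 = 1280.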

	// Total tokens to generate = prev grid + target grid
	// (diffusers does max_new_tokens = total + 1 for EOS, but we stop on EOS anyway)
	cfg.MaxVisualTokens = prevGridTokens + targetGridTokens
	if cfg.Temperature <= 0 {
		cfg.Temperature = 0.9
	}
	if cfg.TopP <= 0 {
		cfg.TopP = 0.75
	}

	// Ensure dimensions are divisible by 32
	cfg.Width = (cfg.Width / 32) * 32
	cfg.Height = (cfg.Height / 32) * 32

	tcfg := m.Transformer.Config
	latentH := cfg.Height / 8
	latentW := cfg.Width / 8

	// Progress callback helper
	progress := func(stage string, step, total int) {
		if cfg.Progress != nil {
			cfg.Progress(stage, step, total)
		}
	}

	// === PHASE 1: T5 Text Encoding ===
	fmt.Println("[T5] Encoding glyph text...")
	progress("text_encoding", 0, 1)
	textEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
	mlx.Keep(textEmbed)
	mlx.Eval(textEmbed)
	fmt.Printf("[T5] Done, shape: %v\n", textEmbed.Shape())
	progress("text_encoding", 1, 1)

	// === PHASE 2: AR Visual Token Generation ===
	fmt.Printf("[AR] Generating %d visual tokens...\n", cfg.MaxVisualTokens)
	progress("ar_generation", 0, int(cfg.MaxVisualTokens))
	visualTokens := m.VisionLanguageEncoder.Generate(
		cfg.Prompt,
		m.GLMTokenizer,
		cfg.MaxVisualTokens,
		cfg.Temperature,
		cfg.TopP,
		cfg.Seed,
		cfg.Height,
		cfg.Width,
		func(step int) {
			if step%100 == 0 || step < 10 {
				fmt.Printf("[AR] Step %d/%d\n", step, cfg.MaxVisualTokens)
			}
			progress("ar_generation", step, int(cfg.MaxVisualTokens))
		},
	)
	mlx.Keep(visualTokens)
	mlx.Eval(visualTokens)
	fmt.Printf("[AR] Done generating visual tokens\n")
	progress("ar_generation", int(cfg.MaxVisualTokens), int(cfg.MaxVisualTokens))

	vtShape := visualTokens.Shape()
	totalGenerated := vtShape[1]
	fmt.Printf("[AR] Generated %d tokens total\n", totalGenerated)

	// Extract only the TARGET grid tokens (skip the prev grid tokens)
	// diffusers: large_image_tokens = outputs[input_length + large_image_start_offset : ...]
	// large_image_start_offset = prev_grid_size
	var targetGridVisualTokens *mlx.Array
	if totalGenerated >= prevGridTokens+targetGridTokens {
		// Full generation completed - extract target grid
		targetGridVisualTokens = mlx.Slice(visualTokens,
			[]int32{0, prevGridTokens},
			[]int32{1, prevGridTokens + targetGridTokens})
		mlx.Keep(targetGridVisualTokens)
		mlx.Eval(targetGridVisualTokens)
	} else if totalGenerated > prevGridTokens {
		// Partial target grid - take what we have
		actualTargetTokens := totalGenerated - prevGridTokens
		targetGridVisualTokens = mlx.Slice(visualTokens,
			[]int32{0, prevGridTokens},
			[]int32{1, totalGenerated})
		mlx.Keep(targetGridVisualTokens)
		mlx.Eval(targetGridVisualTokens)
		fmt.Printf("WARNING: Partial target grid: got %d/%d target tokens\n",
			actualTargetTokens, targetGridTokens)
	} else {
		// Not enough tokens - EOS came too early
		return nil, fmt.Errorf("AR generation stopped too early: got %d tokens, need at least %d (prev grid) + 1",
			totalGenerated, prevGridTokens)
	}

	// === PHASE 3: Diffusion Decoding ===
	// Setup scheduler with dynamic shift based on image size
	scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
	imgSeqLen := (latentH / tcfg.PatchSize) * (latentW / tcfg.PatchSize)
	scheduler.SetTimestepsWithDynamicShift(cfg.Steps, imgSeqLen)

	// Initialize noise latents [B, C, H, W]
	latents := scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
	mlx.Eval(latents)

	// Upsample TARGET grid visual tokens 2x to match patch count (matching diffusers)
	// target_grid tokens -> 2x upsample -> patch_count
	// e.g., 32x32=1024 tokens -> 64x64=4096 patches for 1024x1024
	visualTokensUpsampled := upsampleTokens(targetGridVisualTokens, tokenH, tokenW, 2)

	// Prepare prior embeddings from upsampled visual tokens (VQ codebook lookup + projection)
	priorEmbed := m.Transformer.EmbedPriorTokens(visualTokensUpsampled)
	mlx.Keep(priorEmbed)
	mlx.Eval(priorEmbed)

	// Prepare text conditioning (project T5 embeddings)
	textCond := m.Transformer.ProjectTextEmbeddings(textEmbed)
	mlx.Keep(textCond)
	mlx.Eval(textCond)

	// === CFG Setup ===
	// For classifier-free guidance, we need unconditional (negative) text embeddings
	// GLM-Image uses empty string "" for negative prompt
	doCFG := cfg.GuidanceScale > 1.0
	var negativeTextCond *mlx.Array
	if doCFG {
		// Encode empty string for negative prompt
		negativeTextEmbed := m.TextEncoder.EncodePrompt(m.Tokenizer, "")
		mlx.Keep(negativeTextEmbed)
		mlx.Eval(negativeTextEmbed)
		negativeTextCond = m.Transformer.ProjectTextEmbeddings(negativeTextEmbed)
		mlx.Keep(negativeTextCond)
		mlx.Eval(negativeTextCond)
		negativeTextEmbed.Free()
	}

	// Prepare conditioning inputs
	targetSize := mlx.NewArray([]float32{float32(cfg.Height), float32(cfg.Width)}, []int32{1, 2})
	cropCoords := mlx.NewArray([]float32{0, 0}, []int32{1, 2}) // Default: no crop offset
	targetSize = mlx.ToBFloat16(targetSize)
	cropCoords = mlx.ToBFloat16(cropCoords)
	mlx.Keep(targetSize)
	mlx.Keep(cropCoords)
	mlx.Eval(targetSize, cropCoords)

	pH := latentH / tcfg.PatchSize
	pW := latentW / tcfg.PatchSize

	// Denoising loop
	fmt.Printf("[Diffusion] Starting %d denoising steps...\n", cfg.Steps)
	progress("diffusion", 0, cfg.Steps)
	for i := 0; i < cfg.Steps; i++ {
		fmt.Printf("[Diffusion] Step %d/%d (timestep=%.1f)\n", i+1, cfg.Steps, scheduler.Timesteps[i]-1)
		// Check for cancellation
		if ctx != nil {
			select {
			case <-ctx.Done():
				textEmbed.Free()
				visualTokens.Free()
				// visualTokensUpsampled points to visualTokens, don't double-free
				priorEmbed.Free()
				textCond.Free()
				latents.Free()
				return nil, ctx.Err()
			default:
			}
		}

		// Get timestep value for the transformer
		// scheduler.Timesteps contains raw timestep values (1000 down to ~20)
		// Pass timestep - 1 to match diffusers: timestep = t.expand(latents.shape[0]) - 1
		timestepVal := scheduler.Timesteps[i] - 1
		timestep := mlx.ToBFloat16(mlx.NewArray([]float32{timestepVal}, []int32{1}))

		// Patchify latents [B, C, H, W] -> [B, L, C*p*p]
		patches := PatchifyLatents(latents, tcfg.PatchSize)

		// Transformer forward with MMDiT architecture
		// Conditional pass (with text + prior embeddings)
		outputCond := m.Transformer.ForwardWithPriorDrop(
			patches,
			priorEmbed,
			textCond,
			timestep,
			targetSize,
			cropCoords,
			pH,
			pW,
			false, // priorTokenDrop = false for conditional
		)

		// Unpatchify [B, L, C*p*p] -> [B, C, H, W]
		noisePredCond := UnpatchifyLatents(outputCond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

		var noisePred *mlx.Array
		if doCFG {
			// Unconditional pass (empty text, dropped prior embeddings)
			outputUncond := m.Transformer.ForwardWithPriorDrop(
				patches,
				priorEmbed, // Still passed but will be ignored due to priorTokenDrop=true
				negativeTextCond,
				timestep,
				targetSize,
				cropCoords,
				pH,
				pW,
				true, // priorTokenDrop = true for unconditional
			)
			noisePredUncond := UnpatchifyLatents(outputUncond, latentH, latentW, tcfg.PatchSize, tcfg.OutChannels)

			// CFG formula: noise_pred = uncond + guidance_scale * (cond - uncond)
			diff := mlx.Sub(noisePredCond, noisePredUncond)
			scaled := mlx.MulScalar(diff, cfg.GuidanceScale)
			noisePred = mlx.Add(noisePredUncond, scaled)
		} else {
			noisePred = noisePredCond
		}

		// Scheduler step
		oldLatents := latents
		latents = scheduler.Step(noisePred, latents, i)
		mlx.Eval(latents)
		oldLatents.Free()

		progress("diffusion", i+1, cfg.Steps)
	}

	// Cleanup intermediate arrays
	textEmbed.Free()
	visualTokens.Free()
	// visualTokensUpsampled points to visualTokens, don't double-free
	priorEmbed.Free()
	textCond.Free()
	if negativeTextCond != nil {
		negativeTextCond.Free()
	}
	targetSize.Free()
	cropCoords.Free()

	// === PHASE 4: VAE Decode ===
	progress("vae_decode", 0, 1)
	decoded := m.VAEDecoder.Decode(latents)
	mlx.Eval(decoded)
	latents.Free()
	progress("vae_decode", 1, 1)

	return decoded, nil
}

// upsampleTokens performs nearest-neighbor upsampling of visual tokens
// Converts from prev_grid (e.g., 16x16) to target_grid (e.g., 32x32 for 2x, 64x64 for 4x)
// scale must be 2 or 4
//
// Handles early EOS gracefully: if tokens has fewer than prevH*prevW elements,
// missing tokens are padded with 0 (visual token padding value).
func upsampleTokens(tokens *mlx.Array, prevH, prevW int32, scale int32) *mlx.Array {
	// tokens: [1, N] where N <= prevH*prevW (may be shorter if early EOS)
	// Each token at (i, j) becomes scale*scale tokens in the output

	mlx.Eval(tokens)
	tokenData := tokens.DataInt32()
	numTokens := int32(len(tokenData))
	expectedTokens := prevH * prevW

	// Warn if we got fewer tokens than expected (early EOS)
	if numTokens < expectedTokens {
		fmt.Printf("WARNING: upsampleTokens got %d tokens, expected %d (padding with 0)\n",
			numTokens, expectedTokens)
	}

	targetH := prevH * scale
	targetW := prevW * scale
	upsampled := make([]int32, targetH*targetW)

	for i := int32(0); i < prevH; i++ {
		for j := int32(0); j < prevW; j++ {
			srcIdx := i*prevW + j

			// Handle early EOS: use 0 (padding) for missing tokens
			var val int32
			if srcIdx < numTokens {
				val = tokenData[srcIdx]
			} else {
				val = 0 // Padding token
			}

			// Place in scale*scale positions
			dstI := i * scale
			dstJ := j * scale
			for di := int32(0); di < scale; di++ {
				for dj := int32(0); dj < scale; dj++ {
					upsampled[(dstI+di)*targetW+(dstJ+dj)] = val
				}
			}
		}
	}

	return mlx.NewArrayInt32(upsampled, []int32{1, targetH * targetW})
}
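// Editor's note (illustrative, not in the original source): with scale=2 a
// 2x2 grid [1 2; 3 4] expands to the 4x4 grid
// [1 1 2 2; 1 1 2 2; 3 3 4 4; 3 3 4 4], flattened row-major in the output.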

// PatchifyLatents converts [B, C, H, W] to [B, L, C*p*p]
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
	shape := latents.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	pH := H / patchSize
	pW := W / patchSize

	// Reshape: [B, C, H, W] -> [B, C, pH, p, pW, p]
	x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
	// Transpose: -> [B, pH, pW, C, p, p]
	x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
	// Flatten: -> [B, pH*pW, C*p*p]
	return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
}
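// Editor's note (shape walkthrough with assumed values, not in the original
// source): for B=1, C=16, H=W=128, patchSize=2 the transforms above go
// [1,16,128,128] -> [1,16,64,2,64,2] -> [1,64,64,16,2,2] -> [1,4096,64].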

// UnpatchifyLatents converts [B, L, C*p*p] back to [B, C, H, W]
func UnpatchifyLatents(patches *mlx.Array, H, W, patchSize, channels int32) *mlx.Array {
	shape := patches.Shape()
	B := shape[0]

	pH := H / patchSize
	pW := W / patchSize

	// Reshape: [B, L, C*p*p] -> [B, pH, pW, C, p, p]
	x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
	// Transpose: -> [B, C, pH, p, pW, p]
	x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
	// Reshape: -> [B, C, H, W]
	return mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
}

// CalculateShift computes the dynamic shift for flow matching based on image sequence length.
func CalculateShift(imgSeqLen int32) float32 {
	cfg := DefaultSchedulerConfig()
	if !cfg.UseDynamicShifting {
		return 0
	}

	// Sqrt-based shift calculation (matches diffusers)
	m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
	return m*cfg.MaxShift + cfg.BaseShift
}

// UpsampleTokens2x upsamples token IDs by 2x using nearest neighbor interpolation
// tokens: [B, H*W] -> [B, (H*2)*(W*2)]
// This matches diffusers' _upsample_token_ids function
func UpsampleTokens2x(tokens *mlx.Array, gridH, gridW int32) *mlx.Array {
	shape := tokens.Shape()
	B := shape[0]

	// Reshape to [B, 1, H, W] for interpolation
	tokens = mlx.Reshape(tokens, B, 1, gridH, gridW)

	// Convert to float for interpolation
	tokensFloat := mlx.AsType(tokens, mlx.DtypeFloat32)

	// 2x nearest neighbor upsample
	// [B, 1, H, W] -> [B, 1, H*2, W*2]
	upsampled := nearestUpsample2x(tokensFloat)

	// Convert back to int and reshape to [B, H*2*W*2]
	upsampled = mlx.AsType(upsampled, mlx.DtypeInt32)
	return mlx.Reshape(upsampled, B, gridH*2*gridW*2)
}

// nearestUpsample2x performs 2x nearest neighbor upsampling on NCHW tensor
func nearestUpsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	C := shape[1]
	H := shape[2]
	W := shape[3]

	// Repeat each element 2x2
	// [B, C, H, W] -> [B, C, H, 1, W, 1] -> [B, C, H, 2, W, 2] -> [B, C, H*2, W*2]
	x = mlx.Reshape(x, B, C, H, 1, W, 1)

	// Tile to repeat each pixel 2x2
	x = mlx.Tile(x, []int32{1, 1, 1, 2, 1, 2})

	// Reshape to final size
	return mlx.Reshape(x, B, C, H*2, W*2)
}
x/imagegen/models/glm_image/glm_tokenizer.go (new file, +358)
@@ -0,0 +1,358 @@
//go:build mlx

package glm_image

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/ollama/ollama/x/imagegen"
)

// GLMTokenizer implements the GLM tokenizer for the AR model
// This is a BPE-style tokenizer with ignore_merges=true, meaning it does
// greedy longest-match tokenization from the vocab without runtime merging.
type GLMTokenizer struct {
	Vocab         map[string]int32 // token string -> token ID
	VocabReverse  map[int32]string // token ID -> token string
	SpecialTokens map[string]int32 // special token strings -> IDs

	// Special token IDs
	SopTokenID int32 // <sop> = grid_bos_token (167845)
	EopTokenID int32 // <eop> = grid_eos_token (167846)
	BosTokenID int32 // <|dit_token_16384|> = visual BOS (16384)
	EosTokenID int32 // <|dit_token_16385|> = visual EOS (16385)
	PadTokenID int32

	// Sorted vocab keys by length (longest first) for greedy matching
	sortedTokens []string
}

// tokenizerJSON represents the structure of tokenizer.json
type tokenizerJSON struct {
	Model struct {
		Vocab map[string]int32 `json:"vocab"`
	} `json:"model"`
	AddedTokens []struct {
		ID      int32  `json:"id"`
		Content string `json:"content"`
		Special bool   `json:"special"`
	} `json:"added_tokens"`
}

// NewGLMTokenizer creates a GLM tokenizer from the model manifest
func NewGLMTokenizer(manifest *imagegen.ModelManifest) (*GLMTokenizer, error) {
	// Read tokenizer.json from processor directory in manifest
	data, err := manifest.ReadConfig("processor/tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("failed to read tokenizer.json from manifest: %w", err)
	}

	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
	}

	tok := &GLMTokenizer{
		Vocab:         make(map[string]int32),
		VocabReverse:  make(map[int32]string),
		SpecialTokens: make(map[string]int32),
	}

	// Load vocab from model section
	for token, id := range tj.Model.Vocab {
		tok.Vocab[token] = id
		tok.VocabReverse[id] = token
	}

	// Load added tokens (special tokens including dit_tokens)
	for _, at := range tj.AddedTokens {
		tok.Vocab[at.Content] = at.ID
		tok.VocabReverse[at.ID] = at.Content
		if at.Special {
			tok.SpecialTokens[at.Content] = at.ID
		}
	}

	// Set special token IDs
	tok.SopTokenID = 167845 // <sop>
	tok.EopTokenID = 167846 // <eop>
	tok.BosTokenID = 16384  // <|dit_token_16384|>
	tok.EosTokenID = 16385  // <|dit_token_16385|>
	tok.PadTokenID = 16385  // Same as EOS

	// Build sorted token list for greedy matching (longest first)
	tok.sortedTokens = make([]string, 0, len(tok.Vocab))
	for token := range tok.Vocab {
		tok.sortedTokens = append(tok.sortedTokens, token)
	}
	sort.Slice(tok.sortedTokens, func(i, j int) bool {
		return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
	})

	fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))

	return tok, nil
}

// NewGLMTokenizerFromPath creates a GLM tokenizer from a directory path
func NewGLMTokenizerFromPath(modelPath string) (*GLMTokenizer, error) {
	// Read tokenizer.json from processor directory
	tokenizerPath := filepath.Join(modelPath, "processor", "tokenizer.json")
	data, err := os.ReadFile(tokenizerPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read tokenizer.json: %w", err)
	}

	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("failed to parse tokenizer.json: %w", err)
	}

	tok := &GLMTokenizer{
		Vocab:         make(map[string]int32),
		VocabReverse:  make(map[int32]string),
		SpecialTokens: make(map[string]int32),
	}

	// Load vocab from model section
	for token, id := range tj.Model.Vocab {
		tok.Vocab[token] = id
		tok.VocabReverse[id] = token
	}

	// Load added tokens (special tokens including dit_tokens)
	for _, at := range tj.AddedTokens {
		tok.Vocab[at.Content] = at.ID
		tok.VocabReverse[at.ID] = at.Content
		if at.Special {
			tok.SpecialTokens[at.Content] = at.ID
		}
	}

	// Set special token IDs
	tok.SopTokenID = 167845 // <sop>
	tok.EopTokenID = 167846 // <eop>
	tok.BosTokenID = 16384  // <|dit_token_16384|>
	tok.EosTokenID = 16385  // <|dit_token_16385|>
	tok.PadTokenID = 16385  // Same as EOS

	// Build sorted token list for greedy matching (longest first)
	tok.sortedTokens = make([]string, 0, len(tok.Vocab))
	for token := range tok.Vocab {
		tok.sortedTokens = append(tok.sortedTokens, token)
	}
	sort.Slice(tok.sortedTokens, func(i, j int) bool {
		return len(tok.sortedTokens[i]) > len(tok.sortedTokens[j])
	})

	fmt.Printf("Loaded GLM tokenizer with %d tokens\n", len(tok.Vocab))

	return tok, nil
}

// Encode tokenizes a string into token IDs
// This uses greedy longest-match tokenization with GPT-2 style space handling
func (t *GLMTokenizer) Encode(text string) []int32 {
	if text == "" {
		return []int32{}
	}

	var tokens []int32

	// First, check for and handle special tokens
	// Replace special tokens with placeholders, encode, then restore
	specialReplacements := make(map[string]int32)
	for special, id := range t.SpecialTokens {
		if strings.Contains(text, special) {
			specialReplacements[special] = id
		}
	}

	// Process text character by character with special token handling
	i := 0
	isFirstToken := true

	for i < len(text) {
		// Check for special tokens first
		foundSpecial := false
		for special, id := range specialReplacements {
			if strings.HasPrefix(text[i:], special) {
				tokens = append(tokens, id)
				i += len(special)
				isFirstToken = false
				foundSpecial = true
				break
			}
		}
		if foundSpecial {
			continue
		}

		// Handle regular text with GPT-2 style space prefix
		// "Ġ" (U+0120) represents a space before a token
		remaining := text[i:]

		// Try to find the longest matching token
		matched := false
		for _, token := range t.sortedTokens {
			// Skip special tokens in regular matching
			if _, isSpecial := t.SpecialTokens[token]; isSpecial {
				continue
			}

			// Check if this token matches
			tokenText := token

			// Handle the Ġ prefix (represents space)
			if strings.HasPrefix(token, "Ġ") {
				// This token expects a leading space
				if i > 0 || !isFirstToken {
					// Check if remaining starts with space + token content
					tokenContent := token[len("Ġ"):]
					if strings.HasPrefix(remaining, " "+tokenContent) {
						tokens = append(tokens, t.Vocab[token])
						i += 1 + len(tokenContent) // space + content
						isFirstToken = false
						matched = true
						break
					}
				}
			} else {
				// Regular token without space prefix
				if strings.HasPrefix(remaining, tokenText) {
					tokens = append(tokens, t.Vocab[token])
					i += len(tokenText)
					isFirstToken = false
					matched = true
					break
				}
			}
		}

		if !matched {
			// No token found - skip this character (or use UNK)
			// For now, just skip unknown characters
			i++
		}
	}

	return tokens
}
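// Editor's note (illustrative, not in the original source): under this greedy
// longest-match scheme a prompt like "a cat" is consumed as "a" followed by
// "Ġcat" (assuming "Ġcat" is in the vocab), the space being absorbed into the
// Ġ-prefixed token.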

// EncodeForGeneration encodes a prompt with grid tokens for image generation
// Format: {prompt}<sop>{token_h} {token_w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
//
// Uses GPT-2 style tokenization where " 32" becomes "Ġ32" (a single token with
// space prefix), matching the HuggingFace tokenizer behavior.
func (t *GLMTokenizer) EncodeForGeneration(prompt string, targetHeight, targetWidth int32) []int32 {
	// Calculate grid dimensions
	factor := int32(32)
	height := (targetHeight / factor) * factor
	width := (targetWidth / factor) * factor
	tokenH := height / factor
	tokenW := width / factor

	// Calculate previous grid dimensions
	ratio := float64(tokenH) / float64(tokenW)
	prevTokenH := int32(sqrt(ratio) * 16)
	prevTokenW := int32(sqrt(1.0/ratio) * 16)

	// Encode the prompt text
	promptTokens := t.Encode(prompt)

	// Build the full sequence:
	// [prompt tokens] <sop> [tokenH] [Ġ+tokenW] <eop> <sop> [prevH] [Ġ+prevW] <eop> <bos>
	// Note: HF tokenizer treats " 32" as "Ġ32" (single token), not "Ġ" + "32"
	var tokens []int32
	tokens = append(tokens, promptTokens...)

	// First grid: <sop> H W <eop>
	// First number has no space prefix, second number has space prefix (Ġ)
	tokens = append(tokens, t.SopTokenID)
	tokens = append(tokens, t.encodeNumber(tokenH)...)
	tokens = append(tokens, t.encodeSpaceNumber(tokenW)...) // " W" as Ġ+W
	tokens = append(tokens, t.EopTokenID)

	// Second grid: <sop> prevH prevW <eop>
	tokens = append(tokens, t.SopTokenID)
	tokens = append(tokens, t.encodeNumber(prevTokenH)...)
	tokens = append(tokens, t.encodeSpaceNumber(prevTokenW)...) // " prevW" as Ġ+prevW
	tokens = append(tokens, t.EopTokenID)

	// BOS token (start of image generation)
	tokens = append(tokens, t.BosTokenID)

	return tokens
}
|
|
||||||
|
// encodeNumber encodes a number - first tries as a whole token, falls back to digit-by-digit
|
||||||
|
func (t *GLMTokenizer) encodeNumber(n int32) []int32 {
|
||||||
|
s := fmt.Sprintf("%d", n)
|
||||||
|
// First try: look up the whole number as a single token
|
||||||
|
if id, ok := t.Vocab[s]; ok {
|
||||||
|
return []int32{id}
|
||||||
|
}
|
||||||
|
// Fallback: encode digit by digit
|
||||||
|
var tokens []int32
|
||||||
|
for _, c := range s {
|
||||||
|
if id, ok := t.Vocab[string(c)]; ok {
|
||||||
|
tokens = append(tokens, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeSpaceNumber encodes " N" as "ĠN" (space-prefixed number) matching HF tokenizer
|
||||||
|
// GPT-2 style: " 32" becomes single token "Ġ32", not "Ġ" + "32"
|
||||||
|
func (t *GLMTokenizer) encodeSpaceNumber(n int32) []int32 {
|
||||||
|
s := fmt.Sprintf("%d", n)
|
||||||
|
|
||||||
|
// First try: look up "Ġ{number}" as a single token (e.g., "Ġ32")
|
||||||
|
spaceToken := "Ġ" + s
|
||||||
|
if id, ok := t.Vocab[spaceToken]; ok {
|
||||||
|
return []int32{id}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: bare space Ġ + number tokens
|
||||||
|
var tokens []int32
|
||||||
|
if spaceID, ok := t.Vocab["Ġ"]; ok {
|
||||||
|
tokens = append(tokens, spaceID)
|
||||||
|
}
|
||||||
|
tokens = append(tokens, t.encodeNumber(n)...)
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqrt is a helper for float64 sqrt
|
||||||
|
func sqrt(x float64) float64 {
|
||||||
|
if x <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// Newton's method
|
||||||
|
z := x
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
z = z - (z*z-x)/(2*z)
|
||||||
|
}
|
||||||
|
return z
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode converts token IDs back to a string
|
||||||
|
func (t *GLMTokenizer) Decode(tokens []int32) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, id := range tokens {
|
||||||
|
if token, ok := t.VocabReverse[id]; ok {
|
||||||
|
// Handle Ġ prefix (convert back to space)
|
||||||
|
if strings.HasPrefix(token, "Ġ") {
|
||||||
|
sb.WriteString(" ")
|
||||||
|
sb.WriteString(token[len("Ġ"):])
|
||||||
|
} else {
|
||||||
|
sb.WriteString(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
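A minimal usage sketch (not part of the diff, assuming a loaded tokenizer tok): for a 1024x1024 target, tokenH = tokenW = 1024/32 = 32, and since the aspect ratio is 1.0, prevTokenH = prevTokenW = 16, so the sequence ends with the two grid blocks and the BOS token:

    // Hypothetical call; the actual token IDs depend on the loaded vocab.
    ids := tok.EncodeForGeneration("a red bicycle", 1024, 1024)
    // layout: [prompt] <sop> "32" "Ġ32" <eop> <sop> "16" "Ġ16" <eop> <bos>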
159
x/imagegen/models/glm_image/scheduler.go
Normal file
@@ -0,0 +1,159 @@
//go:build mlx

package glm_image

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
	NumTrainTimesteps  int32   `json:"num_train_timesteps"`  // 1000
	BaseShift          float32 `json:"base_shift"`           // 0.25
	MaxShift           float32 `json:"max_shift"`            // 0.75
	BaseImageSeqLen    int32   `json:"base_image_seq_len"`   // 256
	MaxImageSeqLen     int32   `json:"max_image_seq_len"`    // 4096
	UseDynamicShifting bool    `json:"use_dynamic_shifting"` // true
	TimeShiftType      string  `json:"time_shift_type"`      // "linear"
}

// DefaultSchedulerConfig returns the default config for GLM-Image
func DefaultSchedulerConfig() *FlowMatchSchedulerConfig {
	return &FlowMatchSchedulerConfig{
		NumTrainTimesteps:  1000,
		BaseShift:          0.25,
		MaxShift:           0.75,
		BaseImageSeqLen:    256,
		MaxImageSeqLen:     4096,
		UseDynamicShifting: true,
		TimeShiftType:      "linear",
	}
}

// FlowMatchScheduler implements FlowMatchEulerDiscreteScheduler
type FlowMatchScheduler struct {
	Config    *FlowMatchSchedulerConfig
	Timesteps []float32 // Raw timesteps for transformer conditioning (unshifted)
	Sigmas    []float32 // Shifted sigmas for Euler step calculation
	NumSteps  int
}

// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchScheduler {
	return &FlowMatchScheduler{Config: cfg}
}

// SetTimestepsWithDynamicShift sets timesteps with dynamic shifting based on image size
// Following diffusers: raw timesteps are used for conditioning, shifted sigmas for step calculation
func (s *FlowMatchScheduler) SetTimestepsWithDynamicShift(numSteps int, imgSeqLen int32) {
	s.NumSteps = numSteps

	// Calculate shift (mu) based on image sequence length
	mu := s.calculateShift(imgSeqLen)

	// Create timesteps: linspace from sigma_max_t to sigma_min_t
	// sigma_max = 1.0, sigma_min ~= 0.001 (near 0 but not exactly 0)
	// Then apply time shift and append terminal sigma=0
	s.Timesteps = make([]float32, numSteps)
	s.Sigmas = make([]float32, numSteps+1) // +1 for terminal sigma

	numTrainTimesteps := float32(s.Config.NumTrainTimesteps)

	// Create base sigmas: linspace from 1.0 to small value (matching diffusers)
	for i := 0; i < numSteps; i++ {
		// linspace from 1000 down to 1 (sigma_min * num_train_timesteps)
		tRaw := numTrainTimesteps - float32(i)*(numTrainTimesteps-1.0)/float32(numSteps-1)
		s.Timesteps[i] = tRaw

		// Convert to sigma [0, 1]
		sigma := tRaw / numTrainTimesteps

		// Apply time shift if enabled
		if s.Config.UseDynamicShifting && mu > 0 {
			sigma = s.applyShift(mu, sigma)
		}

		s.Sigmas[i] = sigma
	}

	// Append terminal sigma = 0 (the final clean image)
	s.Sigmas[numSteps] = 0
}

// calculateShift computes dynamic shift based on image sequence length
// Uses the sqrt-based formula from diffusers:
//	m = (image_seq_len / base_seq_len) ** 0.5
//	mu = m * max_shift + base_shift
func (s *FlowMatchScheduler) calculateShift(imgSeqLen int32) float32 {
	cfg := s.Config

	if !cfg.UseDynamicShifting {
		return 0
	}

	// Sqrt-based shift calculation (matches diffusers pipeline_glm_image.py)
	m := float32(math.Sqrt(float64(imgSeqLen) / float64(cfg.BaseImageSeqLen)))
	mu := m*cfg.MaxShift + cfg.BaseShift
	return mu
}

// applyShift applies time shift transformation
// mu: the computed shift value
// t: sigma value in [0, 1]
func (s *FlowMatchScheduler) applyShift(mu float32, t float32) float32 {
	if t <= 0 {
		return 0
	}
	if t >= 1 {
		return 1
	}

	// sigma=1.0 for both shift types
	sigma := float32(1.0)

	if s.Config.TimeShiftType == "linear" {
		// Linear: mu / (mu + (1/t - 1)^sigma)
		return mu / (mu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
	}

	// Exponential (default): exp(mu) / (exp(mu) + (1/t - 1)^sigma)
	expMu := float32(math.Exp(float64(mu)))
	return expMu / (expMu + float32(math.Pow(float64(1.0/t-1.0), float64(sigma))))
}

// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, stepIdx int) *mlx.Array {
	sigma := s.Sigmas[stepIdx]
	sigmaNext := s.Sigmas[stepIdx+1]

	// Euler step: x_{t-dt} = x_t + dt * v_t
	dt := sigmaNext - sigma // Negative (going from noise to clean)

	scaledOutput := mlx.MulScalar(modelOutput, dt)
	return mlx.Add(sample, scaledOutput)
}

// InitNoise creates initial noise
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
	return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}

// AddNoise adds noise to clean samples for a given timestep (for img2img)
func (s *FlowMatchScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
	// In flow matching: x_t = (1-sigma) * x_0 + sigma * noise
	// Use sigmas (shifted) for the interpolation
	sigma := s.Sigmas[timestepIdx]
	oneMinusSigma := 1.0 - sigma

	scaledClean := mlx.MulScalar(cleanSample, oneMinusSigma)
	scaledNoise := mlx.MulScalar(noise, sigma)

	return mlx.Add(scaledClean, scaledNoise)
}

// GetTimesteps returns all timesteps
func (s *FlowMatchScheduler) GetTimesteps() []float32 {
	return s.Timesteps
}
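A sketch of how the scheduler is driven, with a placeholder standing in for the transformer output: with the default config and a 64x64 latent token grid, imgSeqLen = 4096, so m = sqrt(4096/256) = 4 and mu = 4*0.75 + 0.25 = 3.25.

    s := NewFlowMatchScheduler(DefaultSchedulerConfig())
    s.SetTimestepsWithDynamicShift(50, 64*64) // 50 inference steps
    sample := s.InitNoise([]int32{1, 16, 64, 64}, 42)
    for i := 0; i < s.NumSteps; i++ {
        // Placeholder: the real velocity is the transformer output
        // conditioned on s.Timesteps[i].
        velocity := mlx.Zeros([]int32{1, 16, 64, 64}, mlx.DtypeBFloat16)
        sample = s.Step(velocity, sample, i)
    }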
497
x/imagegen/models/glm_image/text_encoder.go
Normal file
@@ -0,0 +1,497 @@
//go:build mlx

package glm_image

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"regexp"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// T5Config holds T5 encoder configuration
type T5Config struct {
	DModel       int32   `json:"d_model"`            // 1472
	DFF          int32   `json:"d_ff"`               // 3584
	DKV          int32   `json:"d_kv"`               // 64
	NumHeads     int32   `json:"num_heads"`          // 6
	NumLayers    int32   `json:"num_layers"`         // 12
	VocabSize    int32   `json:"vocab_size"`         // 384 (byte-level)
	LayerNormEps float32 `json:"layer_norm_epsilon"` // 1e-6
	IsGatedAct   bool    `json:"is_gated_act"`       // true (gated-gelu)

	// Relative position bias
	RelativeAttentionNumBuckets  int32 `json:"relative_attention_num_buckets"`  // 32
	RelativeAttentionMaxDistance int32 `json:"relative_attention_max_distance"` // 128
}

// T5TextEncoder is the T5 encoder for text conditioning
type T5TextEncoder struct {
	Config *T5Config

	// Embedding (shared for ByT5)
	SharedEmbed *nn.Embedding `weight:"shared"`

	// Encoder layers
	Layers []*T5Block `weight:"encoder.block"`

	// Final layer norm
	FinalNorm *T5LayerNorm `weight:"encoder.final_layer_norm"`

	// Relative position bias (from first layer, shared across all)
	RelativeAttentionBias *mlx.Array `weight:"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"`
}

// T5Block is a single T5 encoder block
type T5Block struct {
	// Self attention
	Layer0 *T5LayerSelfAttention `weight:"layer.0"`
	// FFN
	Layer1 *T5LayerFF `weight:"layer.1"`
}

// T5LayerSelfAttention is T5's self-attention layer
type T5LayerSelfAttention struct {
	SelfAttention *T5Attention `weight:"SelfAttention"`
	LayerNorm     *T5LayerNorm `weight:"layer_norm"`
}

// T5Attention implements T5's relative attention
type T5Attention struct {
	Q *mlx.Array `weight:"q.weight"` // No bias in T5
	K *mlx.Array `weight:"k.weight"`
	V *mlx.Array `weight:"v.weight"`
	O *mlx.Array `weight:"o.weight"`

	NHeads int32
	DKV    int32
	Scale  float32
}

// T5LayerFF is T5's feedforward layer with gated-gelu
type T5LayerFF struct {
	DenseReluDense *T5DenseGatedGelu `weight:"DenseReluDense"`
	LayerNorm      *T5LayerNorm      `weight:"layer_norm"`
}

// T5DenseGatedGelu is T5's gated-gelu FFN
type T5DenseGatedGelu struct {
	Wi0 *mlx.Array `weight:"wi_0.weight"` // gate projection
	Wi1 *mlx.Array `weight:"wi_1.weight"` // up projection
	Wo  *mlx.Array `weight:"wo.weight"`   // down projection
}

// T5LayerNorm is T5's RMSNorm variant (no bias, no mean subtraction)
type T5LayerNorm struct {
	Weight *mlx.Array `weight:"weight"`
	Eps    float32
}

// Load loads the T5 text encoder from manifest
func (m *T5TextEncoder) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading T5 text encoder... ")

	// Load config
	var cfg T5Config
	if err := manifest.ReadConfigJSON("text_encoder/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Pre-allocate layers
	m.Layers = make([]*T5Block, cfg.NumLayers)

	// Load weights
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Println("✓")
	return nil
}

// LoadFromPath loads the T5 text encoder from a directory path
func (m *T5TextEncoder) LoadFromPath(path string) error {
	fmt.Print(" Loading T5 text encoder... ")

	// Load config
	var cfg T5Config
	configPath := filepath.Join(path, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	if err := json.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	m.Config = &cfg

	// Pre-allocate layers
	m.Layers = make([]*T5Block, cfg.NumLayers)

	// Load weights from safetensors files
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(0); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Println("✓")
	return nil
}

func (m *T5TextEncoder) initComputedFields() {
	cfg := m.Config
	m.FinalNorm.Eps = cfg.LayerNormEps
	for _, block := range m.Layers {
		attn := block.Layer0.SelfAttention
		attn.NHeads = cfg.NumHeads
		attn.DKV = cfg.DKV
		attn.Scale = float32(1.0 / math.Sqrt(float64(cfg.DKV)))

		block.Layer0.LayerNorm.Eps = cfg.LayerNormEps
		block.Layer1.LayerNorm.Eps = cfg.LayerNormEps
	}
}

// Forward encodes text tokens
func (m *T5TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Get embeddings
	h := m.SharedEmbed.Forward(tokens)

	// Compute relative position bias once
	seqLen := tokens.Shape()[1]
	posBias := m.computeRelativePositionBias(seqLen)

	// Forward through layers
	for _, block := range m.Layers {
		h = block.Forward(h, posBias, cfg.LayerNormEps)
	}

	// Final norm
	h = m.FinalNorm.Forward(h)

	return h
}

// extractGlyphTexts extracts quoted text (glyphs) from the prompt
// This matches diffusers' get_glyph_texts from pipeline_glm_image.py
// Glyph texts are used for text rendering guidance in the generated image
func extractGlyphTexts(prompt string) []string {
	var glyphTexts []string

	// Extract text in single quotes: 'text'
	re1 := regexp.MustCompile(`'([^']*)'`)
	for _, match := range re1.FindAllStringSubmatch(prompt, -1) {
		if len(match) > 1 {
			glyphTexts = append(glyphTexts, match[1])
		}
	}

	// Extract text in Unicode curly double quotes: “text”
	re2 := regexp.MustCompile(`“([^“”]*)”`)
	for _, match := range re2.FindAllStringSubmatch(prompt, -1) {
		if len(match) > 1 {
			glyphTexts = append(glyphTexts, match[1])
		}
	}

	// Extract text in ASCII double quotes: "text"
	re3 := regexp.MustCompile(`"([^"]*)"`)
	for _, match := range re3.FindAllStringSubmatch(prompt, -1) {
		if len(match) > 1 {
			glyphTexts = append(glyphTexts, match[1])
		}
	}

	// Extract text in Japanese quotes: 「text」
	re4 := regexp.MustCompile(`「([^「」]*)」`)
	for _, match := range re4.FindAllStringSubmatch(prompt, -1) {
		if len(match) > 1 {
			glyphTexts = append(glyphTexts, match[1])
		}
	}

	return glyphTexts
}
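For example (a hypothetical call), a prompt mixing quote styles yields one glyph string per quoted span, in the order the patterns are tried:

    texts := extractGlyphTexts(`a poster titled 'SALE' with the caption "50% off"`)
    // texts == []string{"SALE", "50% off"}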
// EncodePrompt encodes the prompt text using the ByT5 tokenizer and encoder
// This provides text conditioning for the diffusion transformer via the glyph projector
//
// IMPORTANT: This encodes only the GLYPH TEXTS (quoted strings in the prompt), not the
// full prompt. Glyph texts are used for text rendering guidance in the generated image.
// Multiple glyph texts are encoded and concatenated to form the conditioning signal.
// This matches diffusers' _get_glyph_embeds() behavior.
func (m *T5TextEncoder) EncodePrompt(tok *ByT5Tokenizer, prompt string) *mlx.Array {
	// Extract glyph texts from prompt (text in quotes)
	glyphTexts := extractGlyphTexts(prompt)

	// If no glyph texts found, encode empty string (matches diffusers: [""] fallback)
	if len(glyphTexts) == 0 {
		glyphTexts = []string{""}
	}

	// Encode each glyph text and collect token sequences
	// Matching diffusers' _get_glyph_embeds() which batches all glyph texts
	var allTokenSeqs [][]int32

	for _, glyphText := range glyphTexts {
		// ByT5 uses byte-level encoding: each byte (0-255) -> token (3-258)
		tokens := tok.Encode(glyphText)

		// Add EOS token (1) at the end to match HuggingFace tokenizer behavior
		tokens = append(tokens, tok.EOSTokenID)

		allTokenSeqs = append(allTokenSeqs, tokens)
	}

	// Process each glyph text through the encoder
	var allEmbeddings []*mlx.Array
	for _, tokens := range allTokenSeqs {
		tokenLen := len(tokens)
		if tokenLen == 0 {
			continue
		}

		// Create token array [1, L]
		tokensArr := mlx.NewArrayInt32(tokens, []int32{1, int32(tokenLen)})

		// Forward through encoder
		output := m.Forward(tokensArr)
		mlx.Eval(output)

		allEmbeddings = append(allEmbeddings, output)
	}

	// Concatenate all glyph embeddings along sequence dimension
	var output *mlx.Array
	if len(allEmbeddings) == 0 {
		// Fallback: return single zero embedding
		output = mlx.Zeros([]int32{1, 1, m.Config.DModel}, mlx.DtypeBFloat16)
	} else if len(allEmbeddings) == 1 {
		output = allEmbeddings[0]
	} else {
		output = mlx.Concatenate(allEmbeddings, 1)
	}
	mlx.Eval(output)

	return output
}
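A minimal usage sketch, assuming a loaded encoder enc and ByT5 tokenizer bt: the result is a [1, L, d_model] conditioning tensor whose length L is the total byte count of the quoted spans plus one EOS each (d_model = 1472 per the config comments).

    embeds := enc.EncodePrompt(bt, `a sign that says "OPEN"`)
    // embeds shape: [1, 5, 1472] - 4 bytes of OPEN plus EOS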
// computeRelativePositionBias computes T5's relative position encoding
func (m *T5TextEncoder) computeRelativePositionBias(seqLen int32) *mlx.Array {
	cfg := m.Config

	// Create relative position matrix
	// For each (query_pos, key_pos) pair, compute bucketed relative position
	numBuckets := cfg.RelativeAttentionNumBuckets
	maxDistance := cfg.RelativeAttentionMaxDistance

	// Create position indices
	contextPos := make([]int32, seqLen*seqLen)
	memoryPos := make([]int32, seqLen*seqLen)
	for i := int32(0); i < seqLen; i++ {
		for j := int32(0); j < seqLen; j++ {
			contextPos[i*seqLen+j] = i
			memoryPos[i*seqLen+j] = j
		}
	}

	// Compute relative positions and bucket them
	buckets := make([]int32, seqLen*seqLen)
	for i := int32(0); i < seqLen*seqLen; i++ {
		relPos := memoryPos[i] - contextPos[i]
		buckets[i] = relativePositionBucket(relPos, numBuckets, maxDistance, false)
	}

	// Create bucket indices array
	bucketsArr := mlx.NewArrayInt32(buckets, []int32{seqLen, seqLen})

	// Look up bias: RelativeAttentionBias shape is [numBuckets, numHeads] = [32, 6]
	// Take along axis 0 (buckets dimension) -> [seqLen, seqLen, numHeads]
	bias := mlx.Take(m.RelativeAttentionBias, bucketsArr, 0) // [seqLen, seqLen, numHeads]

	// Transpose to [numHeads, seqLen, seqLen]
	bias = mlx.Transpose(bias, 2, 0, 1) // [numHeads, seqLen, seqLen]
	bias = mlx.ExpandDims(bias, 0)      // [1, numHeads, seqLen, seqLen]

	return bias
}

// relativePositionBucket computes the bucket for a relative position
func relativePositionBucket(relativePosition int32, numBuckets int32, maxDistance int32, bidirectional bool) int32 {
	var bucket int32 = 0
	var n int32 = -relativePosition

	if bidirectional {
		numBuckets /= 2
		if n < 0 {
			bucket += numBuckets
			n = -n
		}
	} else {
		if n < 0 {
			n = 0
		}
	}

	// Half buckets are for exact positions, half are for log-spaced
	maxExact := numBuckets / 2
	if n < maxExact {
		bucket += n
	} else {
		// Log-spaced buckets
		logVal := math.Log(float64(n)/float64(maxExact)) / math.Log(float64(maxDistance)/float64(maxExact))
		bucket += maxExact + int32(logVal*float64(numBuckets-maxExact))
		if bucket > numBuckets-1 {
			bucket = numBuckets - 1
		}
	}

	return bucket
}
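Two worked values for the parameters used above (numBuckets = 32, maxDistance = 128, bidirectional = false), where n = -(j - i) is the query-to-key distance: maxExact = 16, so n = 8 falls in the exact range and maps to bucket 8, while n = 50 is log-spaced, giving 16 + int(ln(50/16)/ln(128/16) * 16) = 16 + 8 = 24.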
// Forward for T5Block
func (b *T5Block) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
	// Self attention with residual
	h := b.Layer0.Forward(x, posBias, eps)

	// FFN with residual
	h = b.Layer1.Forward(h, eps)

	return h
}

// Forward for T5LayerSelfAttention
func (l *T5LayerSelfAttention) Forward(x *mlx.Array, posBias *mlx.Array, eps float32) *mlx.Array {
	// Pre-norm
	normed := l.LayerNorm.Forward(x)

	// Attention
	attnOut := l.SelfAttention.Forward(normed, posBias)

	// Residual
	return mlx.Add(x, attnOut)
}

// Forward for T5Attention
func (attn *T5Attention) Forward(x *mlx.Array, posBias *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	// Q, K, V projections (no bias)
	// Weights are [out_features, in_features], so we use matmul with transpose
	q := mlx.Matmul(x, mlx.Transpose(attn.Q, 1, 0))
	k := mlx.Matmul(x, mlx.Transpose(attn.K, 1, 0))
	v := mlx.Matmul(x, mlx.Transpose(attn.V, 1, 0))

	// Reshape to [B, L, nheads, d_kv]
	q = mlx.Reshape(q, B, L, attn.NHeads, attn.DKV)
	k = mlx.Reshape(k, B, L, attn.NHeads, attn.DKV)
	v = mlx.Reshape(v, B, L, attn.NHeads, attn.DKV)

	// Transpose to [B, nheads, L, d_kv]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Attention scores with relative position bias
	// T5 uses UNSCALED dot-product attention: scores = q @ k.T + pos_bias
	// (no 1/sqrt(d_k) scale factor like in standard transformers)
	scores := mlx.Matmul(q, mlx.Transpose(k, 0, 1, 3, 2))
	scores = mlx.Add(scores, posBias)

	// Softmax
	attnWeights := mlx.Softmax(scores, -1)

	// Attend to values
	out := mlx.Matmul(attnWeights, v)

	// Transpose back [B, nheads, L, d_kv] -> [B, L, nheads, d_kv]
	out = mlx.Transpose(out, 0, 2, 1, 3)
	// Reshape to [B, L, D]
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.DKV)

	// Output projection
	out = mlx.Matmul(out, mlx.Transpose(attn.O, 1, 0))

	return out
}

// Forward for T5LayerFF
func (l *T5LayerFF) Forward(x *mlx.Array, eps float32) *mlx.Array {
	// Pre-norm
	normed := l.LayerNorm.Forward(x)

	// FFN
	ffOut := l.DenseReluDense.Forward(normed)

	// Residual
	return mlx.Add(x, ffOut)
}

// geluNew implements the GELU activation with tanh approximation (gelu_new)
// This matches HuggingFace transformers' gelu_new/OpenAI GPT implementation
// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
func geluNew(x *mlx.Array) *mlx.Array {
	sqrt2OverPi := float32(0.7978845608) // sqrt(2/π)
	coeff := float32(0.044715)

	x3 := mlx.Mul(mlx.Mul(x, x), x)
	inner := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(x3, coeff)), sqrt2OverPi)
	return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
}

// Forward for T5DenseGatedGelu (gated-gelu activation)
func (d *T5DenseGatedGelu) Forward(x *mlx.Array) *mlx.Array {
	// Gate projection with GELU activation (T5 v1.1/ByT5 uses gelu_new)
	gate := mlx.Matmul(x, mlx.Transpose(d.Wi0, 1, 0))
	gate = geluNew(gate)

	// Up projection
	up := mlx.Matmul(x, mlx.Transpose(d.Wi1, 1, 0))

	// Gated output
	h := mlx.Mul(gate, up)

	// Down projection
	return mlx.Matmul(h, mlx.Transpose(d.Wo, 1, 0))
}

// Forward for T5LayerNorm (RMSNorm variant)
func (ln *T5LayerNorm) Forward(x *mlx.Array) *mlx.Array {
	// T5 uses RMSNorm: x * rsqrt(mean(x^2) + eps) * weight
	variance := mlx.Mean(mlx.Square(x), -1, true)
	x = mlx.Mul(x, mlx.RSqrt(mlx.AddScalar(variance, ln.Eps)))
	return mlx.Mul(x, ln.Weight)
}
1255
x/imagegen/models/glm_image/transformer.go
Normal file
File diff suppressed because it is too large
477
x/imagegen/models/glm_image/vae.go
Normal file
@@ -0,0 +1,477 @@
//go:build mlx

package glm_image

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
	InChannels       int32     `json:"in_channels"`        // 3
	OutChannels      int32     `json:"out_channels"`       // 3
	LatentChannels   int32     `json:"latent_channels"`    // 16
	BlockOutChannels []int32   `json:"block_out_channels"` // [128, 512, 1024, 1024]
	LayersPerBlock   int32     `json:"layers_per_block"`   // 3
	NormNumGroups    int32     `json:"norm_num_groups"`    // 32
	ScalingFactor    float32   `json:"scaling_factor"`     // 0.18215
	ShiftFactor      *float32  `json:"shift_factor"`       // null
	LatentsMean      []float32 `json:"latents_mean"`       // [16 values]
	LatentsStd       []float32 `json:"latents_std"`        // [16 values]
}

// VAEDecoder is the VAE latent decoder
type VAEDecoder struct {
	Config *VAEConfig

	// Decoder components
	ConvIn      *VAEConv2d    `weight:"decoder.conv_in"`
	MidBlock    *VAEMidBlock  `weight:"decoder.mid_block"`
	UpBlocks    []*VAEUpBlock `weight:"decoder.up_blocks"`
	ConvNormOut *GroupNorm    `weight:"decoder.conv_norm_out"`
	ConvOut     *VAEConv2d    `weight:"decoder.conv_out"`
}

// VAEConv2d is a 2D convolution layer
type VAEConv2d struct {
	Weight  *mlx.Array `weight:"weight"`
	Bias    *mlx.Array `weight:"bias"`
	Stride  int32
	Padding int32
}

// GroupNorm is group normalization
type GroupNorm struct {
	Weight    *mlx.Array `weight:"weight"`
	Bias      *mlx.Array `weight:"bias"`
	NumGroups int32
	Eps       float32
}

// VAEMidBlock is the middle block of the VAE
type VAEMidBlock struct {
	Resnets []*VAEResnetBlock `weight:"resnets"`
}

// VAEUpBlock is an upsampling block
type VAEUpBlock struct {
	Resnets    []*VAEResnetBlock `weight:"resnets"`
	Upsamplers []*VAEUpsampler   `weight:"upsamplers"`
}

// VAEResnetBlock is a residual block
type VAEResnetBlock struct {
	Norm1        *GroupNorm `weight:"norm1"`
	Conv1        *VAEConv2d `weight:"conv1"`
	Norm2        *GroupNorm `weight:"norm2"`
	Conv2        *VAEConv2d `weight:"conv2"`
	ConvShortcut *VAEConv2d `weight:"conv_shortcut,optional"` // Optional, for channel mismatch
}

// VAEUpsampler is an upsampling layer
type VAEUpsampler struct {
	Conv *VAEConv2d `weight:"conv"`
}

// Load loads the VAE decoder from manifest
func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
	fmt.Print(" Loading VAE decoder... ")

	// Load config
	var cfg VAEConfig
	if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Config = &cfg

	// Initialize structure based on config
	numBlocks := len(cfg.BlockOutChannels)
	m.UpBlocks = make([]*VAEUpBlock, numBlocks)

	// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
	m.MidBlock = &VAEMidBlock{
		Resnets: make([]*VAEResnetBlock, 2),
	}

	// Pre-allocate UpBlocks with their resnets and upsamplers
	// VAE decoder has layers_per_block+1 resnets per up_block (to match encoder)
	// And all but the last up_block has an upsampler
	for i := 0; i < numBlocks; i++ {
		numResnets := cfg.LayersPerBlock + 1 // typically 4 resnets
		m.UpBlocks[i] = &VAEUpBlock{
			Resnets: make([]*VAEResnetBlock, numResnets),
		}
		// All but the last block has upsamplers
		if i < numBlocks-1 {
			m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
		}
	}

	// Load weights
	weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	// Initialize GroupNorm parameters
	m.initGroupNorms()

	fmt.Println("✓")
	return nil
}

// LoadFromPath loads the VAE decoder from a directory path
func (m *VAEDecoder) LoadFromPath(path string) error {
	fmt.Print(" Loading VAE decoder... ")

	// Load config
	var cfg VAEConfig
	configPath := filepath.Join(path, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	if err := json.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	m.Config = &cfg

	// Initialize structure based on config
	numBlocks := len(cfg.BlockOutChannels)
	m.UpBlocks = make([]*VAEUpBlock, numBlocks)

	// Pre-allocate MidBlock resnets (VAE mid_block typically has 2 resnets)
	m.MidBlock = &VAEMidBlock{
		Resnets: make([]*VAEResnetBlock, 2),
	}

	// Pre-allocate UpBlocks with their resnets and upsamplers
	for i := 0; i < numBlocks; i++ {
		numResnets := cfg.LayersPerBlock + 1
		m.UpBlocks[i] = &VAEUpBlock{
			Resnets: make([]*VAEResnetBlock, numResnets),
		}
		if i < numBlocks-1 {
			m.UpBlocks[i].Upsamplers = make([]*VAEUpsampler, 1)
		}
	}

	// Load weights from safetensors files
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	// Initialize GroupNorm parameters
	m.initGroupNorms()

	fmt.Println("✓")
	return nil
}

func (m *VAEDecoder) initGroupNorms() {
	cfg := m.Config
	numGroups := cfg.NormNumGroups
	eps := float32(1e-6) // Must match diffusers VAE (1e-6, not 1e-5)

	if m.ConvNormOut != nil {
		m.ConvNormOut.NumGroups = numGroups
		m.ConvNormOut.Eps = eps
	}

	if m.MidBlock != nil {
		for _, resnet := range m.MidBlock.Resnets {
			if resnet.Norm1 != nil {
				resnet.Norm1.NumGroups = numGroups
				resnet.Norm1.Eps = eps
			}
			if resnet.Norm2 != nil {
				resnet.Norm2.NumGroups = numGroups
				resnet.Norm2.Eps = eps
			}
		}
	}

	for _, upBlock := range m.UpBlocks {
		if upBlock == nil {
			continue
		}
		for _, resnet := range upBlock.Resnets {
			if resnet == nil {
				continue
			}
			if resnet.Norm1 != nil {
				resnet.Norm1.NumGroups = numGroups
				resnet.Norm1.Eps = eps
			}
			if resnet.Norm2 != nil {
				resnet.Norm2.NumGroups = numGroups
				resnet.Norm2.Eps = eps
			}
		}
	}
}

// Decode decodes latents to an image
func (m *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Apply latent denormalization if mean/std are provided
	// This matches diffusers GLM-Image: latents = latents * std + mean
	// Note: GLM-Image does NOT divide by scaling_factor (unlike standard SD VAEs)
	if len(cfg.LatentsMean) > 0 && len(cfg.LatentsStd) > 0 {
		latents = m.denormalizeLatents(latents)
	}

	// Convert from NCHW to NHWC for processing
	// [B, C, H, W] -> [B, H, W, C]
	x := mlx.Transpose(latents, 0, 2, 3, 1)

	// Initial convolution
	x = m.ConvIn.Forward(x)

	// Mid block
	x = m.MidBlock.Forward(x)

	// Up blocks (forward order - index 0 is at lowest resolution/highest channels)
	for i := 0; i < len(m.UpBlocks); i++ {
		if m.UpBlocks[i] != nil {
			x = m.UpBlocks[i].Forward(x)
		}
	}

	// Final normalization and convolution
	x = m.ConvNormOut.Forward(x)
	x = mlx.SiLU(x)
	x = m.ConvOut.Forward(x)

	// Convert back to NCHW
	// [B, H, W, C] -> [B, C, H, W]
	x = mlx.Transpose(x, 0, 3, 1, 2)

	// Clamp to valid range and convert to [0, 1]
	x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
	x = mlx.AddScalar(x, 1.0)
	x = mlx.DivScalar(x, 2.0)

	return x
}

// denormalizeLatents applies the latent mean/std denormalization
func (m *VAEDecoder) denormalizeLatents(latents *mlx.Array) *mlx.Array {
	cfg := m.Config

	// Create mean and std arrays [1, C, 1, 1] for broadcasting
	mean := mlx.NewArray(cfg.LatentsMean, []int32{1, int32(len(cfg.LatentsMean)), 1, 1})
	std := mlx.NewArray(cfg.LatentsStd, []int32{1, int32(len(cfg.LatentsStd)), 1, 1})

	// Denormalize: latents * std + mean
	latents = mlx.Mul(latents, std)
	latents = mlx.Add(latents, mean)

	return latents
}

// Forward for VAEConv2d
func (c *VAEConv2d) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C_in] (NHWC)
	// PyTorch weight: [C_out, C_in, kH, kW] (OIHW)
	// MLX conv2d expects weight: [C_out, kH, kW, C_in] (OHWI)
	// So we need to transpose from OIHW to OHWI

	stride := c.Stride
	if stride == 0 {
		stride = 1
	}
	padding := c.Padding
	if padding == 0 {
		// Default to same padding for 3x3 kernels
		wShape := c.Weight.Shape()
		if len(wShape) >= 3 && wShape[2] == 3 {
			padding = 1
		}
	}

	// Transpose weight from OIHW [out, in, h, w] to OHWI [out, h, w, in]
	weight := mlx.Transpose(c.Weight, 0, 2, 3, 1)

	out := mlx.Conv2d(x, weight, stride, padding)
	if c.Bias != nil {
		// Bias: [C_out] -> [1, 1, 1, C_out]
		bias := mlx.Reshape(c.Bias, 1, 1, 1, -1)
		out = mlx.Add(out, bias)
	}
	return out
}
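As a shape check for the transpose above: a 3x3 kernel mapping C_in to C_out channels is stored by PyTorch as [C_out, C_in, 3, 3] (OIHW) and becomes [C_out, 3, 3, C_in] (OHWI), so mlx.Conv2d consumes it against an NHWC input [B, H, W, C_in] and produces [B, H, W, C_out].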
// Forward for GroupNorm
func (gn *GroupNorm) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C] (NHWC)
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	numGroups := gn.NumGroups
	if numGroups == 0 {
		numGroups = 32
	}
	groupSize := C / numGroups

	// Reshape to [B, H, W, groups, groupSize]
	x = mlx.Reshape(x, B, H, W, numGroups, groupSize)

	// Compute mean and variance per group
	mean := mlx.Mean(x, 1, true)
	mean = mlx.Mean(mean, 2, true)
	mean = mlx.Mean(mean, 4, true)

	xCentered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Square(xCentered), 1, true)
	variance = mlx.Mean(variance, 2, true)
	variance = mlx.Mean(variance, 4, true)

	// Normalize
	xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))

	// Reshape back
	xNorm = mlx.Reshape(xNorm, B, H, W, C)

	// Scale and shift
	if gn.Weight != nil {
		weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
		xNorm = mlx.Mul(xNorm, weight)
	}
	if gn.Bias != nil {
		bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
		xNorm = mlx.Add(xNorm, bias)
	}

	return xNorm
}

// Forward for VAEMidBlock
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
	for _, resnet := range mb.Resnets {
		x = resnet.Forward(x)
	}
	return x
}

// Forward for VAEUpBlock
func (ub *VAEUpBlock) Forward(x *mlx.Array) *mlx.Array {
	// Apply resnets
	for _, resnet := range ub.Resnets {
		if resnet != nil {
			x = resnet.Forward(x)
		}
	}

	// Apply upsamplers
	for _, upsampler := range ub.Upsamplers {
		if upsampler != nil {
			x = upsampler.Forward(x)
		}
	}

	return x
}

// Forward for VAEResnetBlock
func (rb *VAEResnetBlock) Forward(x *mlx.Array) *mlx.Array {
	residual := x

	// First norm + activation + conv
	h := rb.Norm1.Forward(x)
	h = mlx.SiLU(h)
	h = rb.Conv1.Forward(h)

	// Second norm + activation + conv
	h = rb.Norm2.Forward(h)
	h = mlx.SiLU(h)
	h = rb.Conv2.Forward(h)

	// Shortcut for channel mismatch
	if rb.ConvShortcut != nil {
		residual = rb.ConvShortcut.Forward(residual)
	}

	return mlx.Add(h, residual)
}

// Forward for VAEUpsampler (2x nearest neighbor upsample + conv)
func (us *VAEUpsampler) Forward(x *mlx.Array) *mlx.Array {
	// x: [B, H, W, C]
	// 2x nearest neighbor upsample
	x = upsample2x(x)

	// Conv
	if us.Conv != nil {
		x = us.Conv.Forward(x)
	}

	return x
}

// upsample2x performs 2x nearest neighbor upsampling.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
func upsample2x(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	H := shape[1]
	W := shape[2]
	C := shape[3]

	// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
	hIndices := make([]int32, H*2)
	for i := int32(0); i < H; i++ {
		hIndices[i*2] = i
		hIndices[i*2+1] = i
	}
	wIndices := make([]int32, W*2)
	for i := int32(0); i < W; i++ {
		wIndices[i*2] = i
		wIndices[i*2+1] = i
	}

	hIdx := mlx.NewArrayInt32(hIndices, []int32{H * 2})
	wIdx := mlx.NewArrayInt32(wIndices, []int32{W * 2})

	// Take along width axis
	x = mlx.Reshape(x, B*H, W, C)
	x = mlx.Take(x, wIdx, 1) // [B*H, W*2, C]
	x = mlx.Reshape(x, B, H, W*2, C)

	// Take along height axis - transpose to [B, W*2, H, C], take, transpose back
	x = mlx.Transpose(x, 0, 2, 1, 3) // [B, W*2, H, C]
	x = mlx.Reshape(x, B*(W*2), H, C)
	x = mlx.Take(x, hIdx, 1) // [B*(W*2), H*2, C]
	x = mlx.Reshape(x, B, W*2, H*2, C)
	x = mlx.Transpose(x, 0, 2, 1, 3) // [B, H*2, W*2, C]

	return x
}
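A small worked example of upsample2x: a [1, 2, 2, 1] input with rows (a b / c d) becomes [1, 4, 4, 1] with rows (a a b b / a a b b / c c d d / c c d d) - the [0, 0, 1, 1] index vectors select each source row and column twice.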
982
x/imagegen/models/glm_image/vision_language_encoder.go
Normal file
@@ -0,0 +1,982 @@
//go:build mlx

package glm_image

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// VisionLanguageConfig holds GLM-Image AR generator configuration
type VisionLanguageConfig struct {
	// Text model config
	HiddenSize        int32   `json:"hidden_size"`         // 4096
	NumHiddenLayers   int32   `json:"num_hidden_layers"`   // 40
	IntermediateSize  int32   `json:"intermediate_size"`   // 13696
	NumAttentionHeads int32   `json:"num_attention_heads"` // 32
	NumKeyValueHeads  int32   `json:"num_key_value_heads"` // 2
	VocabSize         int32   `json:"vocab_size"`          // 168064
	RMSNormEps        float32 `json:"rms_norm_eps"`        // 1e-5

	// RoPE config
	RopeTheta           float32 `json:"rope_theta"`            // 10000
	PartialRotaryFactor float32 `json:"partial_rotary_factor"` // 0.5
	MRoPESection        []int32 `json:"mrope_section"`         // [8, 12, 12]

	// Visual token config
	VisionVocabSize   int32 `json:"vision_vocab_size"`    // 16512
	ImageStartTokenID int32 `json:"image_start_token_id"` // 16384
	ImageEndTokenID   int32 `json:"image_end_token_id"`   // 16385
	ImageTokenID      int32 `json:"image_token_id"`       // 167855

	// Computed
	HeadDim int32
}

// VisionLanguageEncoder is the 9B AR generator
type VisionLanguageEncoder struct {
	Config *VisionLanguageConfig

	// Embedding
	EmbedTokens *nn.Embedding `weight:"model.language_model.embed_tokens"`

	// Transformer layers
	Layers []*GLMBlock `weight:"model.language_model.layers"`

	// Final norm
	FinalNorm *nn.RMSNorm `weight:"model.language_model.norm"`

	// LM Head
	LMHead *mlx.Array `weight:"lm_head.weight"`
}

// GLMBlock is a single transformer block in GLM-4 style
type GLMBlock struct {
	// Pre-attention norm (GLM uses post-LN variant)
	InputLayerNorm    *nn.RMSNorm `weight:"input_layernorm"`
	PostSelfAttnNorm  *nn.RMSNorm `weight:"post_self_attn_layernorm"`
	PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
	PostMLPLayerNorm  *nn.RMSNorm `weight:"post_mlp_layernorm"`

	// Attention
	SelfAttn *GLMAttention `weight:"self_attn"`

	// MLP (fused gate_up)
	MLP *GLMMLP `weight:"mlp"`
}

// GLMAttention implements GQA with partial rotary and MRoPE
type GLMAttention struct {
	QProj *mlx.Array `weight:"q_proj.weight"`
	KProj *mlx.Array `weight:"k_proj.weight"`
	VProj *mlx.Array `weight:"v_proj.weight"`
	OProj *mlx.Array `weight:"o_proj.weight"`

	// QKV have biases in GLM
	QBias *mlx.Array `weight:"q_proj.bias"`
	KBias *mlx.Array `weight:"k_proj.bias"`
	VBias *mlx.Array `weight:"v_proj.bias"`

	// Computed
	NHeads        int32
	NKVHeads      int32
	HeadDim       int32
	Scale         float32
	PartialRotary float32 // Only rotate this fraction of head_dim
	RopeTheta     float32
	MRoPESection  []int32 // [8, 12, 12] - frequency pairs per dimension (temporal, height, width)
}
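A consistency check on the computed fields, using the values in the config comments: head_dim = 4096/32 = 128, and with partial_rotary_factor = 0.5 only 64 of those dimensions are rotated - exactly 2 * (8 + 12 + 12), the frequency pairs listed in mrope_section for the temporal, height, and width axes.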
|
|
||||||
|
// ARCache holds KV caches for all layers using the shared cache implementation
|
||||||
|
type ARCache struct {
|
||||||
|
Layers []cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewARCache creates a new cache for the given number of layers
|
||||||
|
func NewARCache(numLayers int32) *ARCache {
|
||||||
|
layers := make([]cache.Cache, numLayers)
|
||||||
|
for i := range layers {
|
||||||
|
layers[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
return &ARCache{Layers: layers}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free releases all cached tensors
|
||||||
|
func (c *ARCache) Free() {
|
||||||
|
for _, layer := range c.Layers {
|
||||||
|
for _, arr := range layer.State() {
|
||||||
|
if arr != nil {
|
||||||
|
arr.Free()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GLMMLP implements fused gate_up SwiGLU MLP
|
||||||
|
type GLMMLP struct {
|
||||||
|
// GLM uses fused gate_up_proj: [hidden, 2*intermediate]
|
||||||
|
GateUpProj *mlx.Array `weight:"gate_up_proj.weight"`
|
||||||
|
DownProj *mlx.Array `weight:"down_proj.weight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads the vision-language encoder from manifest
|
||||||
|
func (m *VisionLanguageEncoder) Load(manifest *imagegen.ModelManifest) error {
|
||||||
|
fmt.Print(" Loading vision-language encoder... ")
|
||||||
|
|
||||||
|
// Load config
|
||||||
|
var rawCfg struct {
|
||||||
|
TextConfig struct {
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||||
|
VocabSize int32 `json:"vocab_size"`
|
||||||
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||||
|
VisionVocabSize int32 `json:"vision_vocab_size"`
|
||||||
|
RopeParameters struct {
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
|
MRoPESection []int32 `json:"mrope_section"`
|
||||||
|
} `json:"rope_parameters"`
|
||||||
|
} `json:"text_config"`
|
||||||
|
ImageStartTokenID int32 `json:"image_start_token_id"`
|
||||||
|
ImageEndTokenID int32 `json:"image_end_token_id"`
|
||||||
|
ImageTokenID int32 `json:"image_token_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manifest.ReadConfigJSON("vision_language_encoder/config.json", &rawCfg); err != nil {
|
||||||
|
return fmt.Errorf("config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &VisionLanguageConfig{
|
||||||
|
HiddenSize: rawCfg.TextConfig.HiddenSize,
|
||||||
|
NumHiddenLayers: rawCfg.TextConfig.NumHiddenLayers,
|
||||||
|
IntermediateSize: rawCfg.TextConfig.IntermediateSize,
|
||||||
|
NumAttentionHeads: rawCfg.TextConfig.NumAttentionHeads,
|
||||||
|
NumKeyValueHeads: rawCfg.TextConfig.NumKeyValueHeads,
|
||||||
|
VocabSize: rawCfg.TextConfig.VocabSize,
|
||||||
|
RMSNormEps: rawCfg.TextConfig.RMSNormEps,
|
||||||
|
VisionVocabSize: rawCfg.TextConfig.VisionVocabSize,
|
||||||
|
RopeTheta: rawCfg.TextConfig.RopeParameters.RopeTheta,
|
||||||
|
PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
|
||||||
|
MRoPESection: rawCfg.TextConfig.RopeParameters.MRoPESection,
|
||||||
|
ImageStartTokenID: rawCfg.ImageStartTokenID,
|
||||||
|
ImageEndTokenID: rawCfg.ImageEndTokenID,
|
||||||
|
ImageTokenID: rawCfg.ImageTokenID,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||||
|
m.Config = cfg
|
||||||
|
|
||||||
|
// Pre-allocate layers
|
||||||
|
m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)
|
||||||
|
|
||||||
|
// Load weights
|
||||||
|
weights, err := imagegen.LoadWeightsFromManifest(manifest, "vision_language_encoder")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("weights: %w", err)
|
||||||
|
}
|
||||||
|
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
|
||||||
|
return fmt.Errorf("load weights: %w", err)
|
||||||
|
}
|
||||||
|
defer weights.ReleaseAll()
|
||||||
|
|
||||||
|
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||||
|
return fmt.Errorf("load module: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.initComputedFields()
|
||||||
|
fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
|
||||||
|
return nil
|
||||||
|
}

// LoadFromPath loads the vision-language encoder from a directory path
func (m *VisionLanguageEncoder) LoadFromPath(path string) error {
	fmt.Print(" Loading vision-language encoder... ")

	// Load config
	var rawCfg struct {
		TextConfig struct {
			HiddenSize        int32   `json:"hidden_size"`
			NumHiddenLayers   int32   `json:"num_hidden_layers"`
			IntermediateSize  int32   `json:"intermediate_size"`
			NumAttentionHeads int32   `json:"num_attention_heads"`
			NumKeyValueHeads  int32   `json:"num_key_value_heads"`
			VocabSize         int32   `json:"vocab_size"`
			RMSNormEps        float32 `json:"rms_norm_eps"`
			VisionVocabSize   int32   `json:"vision_vocab_size"`
			RopeParameters    struct {
				RopeTheta           float32 `json:"rope_theta"`
				PartialRotaryFactor float32 `json:"partial_rotary_factor"`
				MRoPESection        []int32 `json:"mrope_section"`
			} `json:"rope_parameters"`
		} `json:"text_config"`
		ImageStartTokenID int32 `json:"image_start_token_id"`
		ImageEndTokenID   int32 `json:"image_end_token_id"`
		ImageTokenID      int32 `json:"image_token_id"`
	}

	configPath := filepath.Join(path, "config.json")
	data, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	if err := json.Unmarshal(data, &rawCfg); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}

	cfg := &VisionLanguageConfig{
		HiddenSize:          rawCfg.TextConfig.HiddenSize,
		NumHiddenLayers:     rawCfg.TextConfig.NumHiddenLayers,
		IntermediateSize:    rawCfg.TextConfig.IntermediateSize,
		NumAttentionHeads:   rawCfg.TextConfig.NumAttentionHeads,
		NumKeyValueHeads:    rawCfg.TextConfig.NumKeyValueHeads,
		VocabSize:           rawCfg.TextConfig.VocabSize,
		RMSNormEps:          rawCfg.TextConfig.RMSNormEps,
		VisionVocabSize:     rawCfg.TextConfig.VisionVocabSize,
		RopeTheta:           rawCfg.TextConfig.RopeParameters.RopeTheta,
		PartialRotaryFactor: rawCfg.TextConfig.RopeParameters.PartialRotaryFactor,
		MRoPESection:        rawCfg.TextConfig.RopeParameters.MRoPESection,
		ImageStartTokenID:   rawCfg.ImageStartTokenID,
		ImageEndTokenID:     rawCfg.ImageEndTokenID,
		ImageTokenID:        rawCfg.ImageTokenID,
	}

	cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	m.Config = cfg

	// Pre-allocate layers
	m.Layers = make([]*GLMBlock, cfg.NumHiddenLayers)

	// Load weights
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	defer weights.ReleaseAll()

	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}

	m.initComputedFields()
	fmt.Printf("✓ [%d layers]\n", cfg.NumHiddenLayers)
	return nil
}
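
// Illustrative only, not part of the original file: a config.json matching
// the rawCfg struct above could look like the sketch below. The token IDs and
// the 16512 vision vocab size follow the values documented in Generate and
// sampleVisualToken; every other number is a placeholder, not GLM-Image's
// actual hyperparameters.
//
//	{
//	  "text_config": {
//	    "hidden_size": 4096,
//	    "num_hidden_layers": 40,
//	    "num_attention_heads": 32,
//	    "num_key_value_heads": 4,
//	    "rms_norm_eps": 1e-5,
//	    "vision_vocab_size": 16512,
//	    "rope_parameters": {
//	      "rope_theta": 10000.0,
//	      "partial_rotary_factor": 0.5,
//	      "mrope_section": [8, 12, 12]
//	    }
//	  },
//	  "image_start_token_id": 16384,
//	  "image_end_token_id": 16385
//	}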

func (m *VisionLanguageEncoder) initComputedFields() {
	cfg := m.Config
	for _, block := range m.Layers {
		block.SelfAttn.NHeads = cfg.NumAttentionHeads
		block.SelfAttn.NKVHeads = cfg.NumKeyValueHeads
		block.SelfAttn.HeadDim = cfg.HeadDim
		block.SelfAttn.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
		block.SelfAttn.PartialRotary = cfg.PartialRotaryFactor
		block.SelfAttn.RopeTheta = cfg.RopeTheta
		block.SelfAttn.MRoPESection = cfg.MRoPESection

		// Set norm eps
		block.InputLayerNorm.Eps = cfg.RMSNormEps
		block.PostSelfAttnNorm.Eps = cfg.RMSNormEps
		block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
		block.PostMLPLayerNorm.Eps = cfg.RMSNormEps
	}
	m.FinalNorm.Eps = cfg.RMSNormEps
}

// Generate autoregressively generates visual tokens with KV caching
func (m *VisionLanguageEncoder) Generate(
	prompt string,
	tok *GLMTokenizer,
	maxTokens int32,
	temperature float32,
	topP float32,
	seed int64,
	targetHeight, targetWidth int32,
	progressFn func(int),
) *mlx.Array {
	cfg := m.Config

	// Encode prompt with grid tokens using GLM tokenizer
	// Format: {prompt}<sop>{h} {w}<eop><sop>{prev_h} {prev_w}<eop><|dit_token_16384|>
	tokens := tok.EncodeForGeneration(prompt, targetHeight, targetWidth)

	// Calculate grid dimensions for MRoPE position IDs
	factor := int32(32)
	tokenH := targetHeight / factor
	tokenW := targetWidth / factor
	ratio := float64(tokenH) / float64(tokenW)
	prevTokenH := int32(math.Sqrt(ratio) * 16)
	prevTokenW := int32(math.Sqrt(1.0/ratio) * 16)
	prevGridSize := prevTokenH * prevTokenW

	// Create KV cache for all layers
	cache := NewARCache(cfg.NumHiddenLayers)
	defer cache.Free()

	// ===== PREFILL PHASE =====
	// Process entire prompt at once, populate cache
	promptLen := int32(len(tokens))
	tokenArr := mlx.NewArrayInt32(tokens, []int32{1, promptLen})
	h := m.EmbedTokens.Forward(tokenArr)
	tokenArr.Free()

	mlx.Eval(h)

	// Compute position IDs for prefill (text tokens use same position for all dims)
	prefillPositions := make([][]int32, 3)
	for dim := 0; dim < 3; dim++ {
		prefillPositions[dim] = make([]int32, promptLen)
		for i := int32(0); i < promptLen; i++ {
			prefillPositions[dim][i] = i
		}
	}

	// Forward through layers (prefill)
	for i, layer := range m.Layers {
		oldH := h
		h = layer.ForwardWithCache(h, promptLen, 0, cfg.RMSNormEps, cache.Layers[i], prefillPositions)
		if i > 0 {
			oldH.Free()
		}
	}
	// Eval h and cache arrays together so cache is materialized
	evalArgs := []*mlx.Array{h}
	for _, lc := range cache.Layers {
		evalArgs = append(evalArgs, lc.State()...)
	}
	mlx.Eval(evalArgs...)

	// Final norm and get logits for last position
	preNormH := h
	h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
	preNormH.Free()

	lastH := mlx.Slice(h, []int32{0, promptLen - 1, 0}, []int32{1, promptLen, cfg.HiddenSize})
	h.Free()
	lastH = mlx.Reshape(lastH, 1, cfg.HiddenSize)
	logits := mlx.Matmul(lastH, mlx.Transpose(m.LMHead, 1, 0))
	lastH.Free()

	// Sample first token
	var sampleCounter int64 = 0
	nextToken := sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
	logits.Free()

	// AR generation loop with caching
	// Visual tokens are stored as VQ codebook indices [0, 16383]
	// The LM head outputs indices [0, 16511] where:
	// - [0, 16383] are VQ codes
	// - 16384 is BOS
	// - 16385 is EOS
	visualTokens := make([]int32, 0, maxTokens)
	posOffset := promptLen
	visualTokenIdx := int32(0) // Index within visual token sequence for grid position calculation

	// Preallocate slice for old cache state to reuse
	oldCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)

	for i := int32(0); i < maxTokens; i++ {
		if progressFn != nil {
			progressFn(int(i))
		}

		// Check for end token (EOS = 16385)
		if nextToken == cfg.ImageEndTokenID {
			break
		}

		// Skip BOS token (16384), only store actual VQ codes [0, 16383]
		if nextToken == cfg.ImageStartTokenID {
			// BOS token - skip storing but continue generation
		} else if nextToken < cfg.ImageStartTokenID {
			// This is an actual VQ code [0, 16383] - store it
			visualTokens = append(visualTokens, nextToken)
		}
		// Tokens >= 16386 are other special tokens, skip them

		// ===== DECODE PHASE =====
		// Save old cache state before forward (to free after eval)
		oldCacheState = oldCacheState[:0]
		for _, lc := range cache.Layers {
			oldCacheState = append(oldCacheState, lc.State()...)
		}

		// Only process the new token, use cached K,V
		tokenArr := mlx.NewArrayInt32([]int32{nextToken}, []int32{1, 1})
		h := m.EmbedTokens.Forward(tokenArr)
		tokenArr.Free()

		// Compute MRoPE position IDs for this visual token
		// Visual tokens are arranged in two grids: prev grid then target grid
		// Position dimensions: [temporal, height, width]
		decodePositions := computeVisualTokenPositions(
			visualTokenIdx, posOffset, promptLen,
			prevTokenH, prevTokenW, prevGridSize,
			tokenH, tokenW,
		)

		// Forward through layers (decode with cache)
		for j, layer := range m.Layers {
			oldH := h
			h = layer.ForwardWithCache(h, 1, posOffset, cfg.RMSNormEps, cache.Layers[j], decodePositions)
			if j > 0 { // Don't free the embedding on first layer
				oldH.Free()
			}
		}

		// Eval h and new cache state
		newCacheState := make([]*mlx.Array, 0, len(m.Layers)*2)
		for _, lc := range cache.Layers {
			newCacheState = append(newCacheState, lc.State()...)
		}
		mlx.Eval(append([]*mlx.Array{h}, newCacheState...)...)

		// Free old cache state (now that new state is evaluated)
		for _, arr := range oldCacheState {
			if arr != nil {
				arr.Free()
			}
		}

		// Final norm
		preNormH := h
		h = m.FinalNorm.Forward(h, cfg.RMSNormEps)
		preNormH.Free()

		// Get logits (h is already [1, 1, hidden_size])
		h = mlx.Reshape(h, 1, cfg.HiddenSize)
		logits := mlx.Matmul(h, mlx.Transpose(m.LMHead, 1, 0))
		h.Free()

		// Sample next token
		nextToken = sampleVisualToken(logits, temperature, topP, cfg, seed, &sampleCounter)
		logits.Free()

		posOffset++
		visualTokenIdx++

		// Periodically clear cache to release intermediate memory
		if i%256 == 0 {
			mlx.ClearCache()
		}
	}

	if len(visualTokens) == 0 {
		// Return at least one token to avoid empty tensor issues
		visualTokens = append(visualTokens, 0)
	}

	return mlx.NewArrayInt32(visualTokens, []int32{1, int32(len(visualTokens))})
}
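
// Illustrative usage sketch, with an assumed call site (the real driver lives
// elsewhere in the runner): enc is a loaded *VisionLanguageEncoder and tok a
// *GLMTokenizer; generate up to 4096 visual tokens for a 1024x1024 target,
// then hand the resulting [1, N] array of VQ codes to the image decoder.
//
//	tokens := enc.Generate(prompt, tok, 4096, 0.9, 0.75, seed, 1024, 1024,
//		func(step int) { fmt.Printf("\rtoken %d", step) })
//	defer tokens.Free()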

// computeVisualTokenPositions computes MRoPE position IDs for a visual token
// Returns [3][1] position IDs for temporal, height, and width dimensions
//
// MRoPE position encoding for GLM-Image visual tokens:
// - temporal: CONSTANT within each grid (= decode_pos at grid start)
// - height: decode_pos + row index within grid
// - width: decode_pos + column index within grid
//
// Between grids, decode_pos advances by max(grid_h, grid_w) to ensure
// sufficient positional separation.
func computeVisualTokenPositions(
	visualIdx int32, absPos int32, promptLen int32,
	prevH, prevW, prevSize int32,
	targetH, targetW int32,
) [][]int32 {
	positions := make([][]int32, 3)
	for dim := 0; dim < 3; dim++ {
		positions[dim] = make([]int32, 1)
	}

	// First grid (prev grid) starts at decode_pos = promptLen
	prevGridDecodePos := promptLen

	// Second grid (target grid) starts after first grid
	// next_pos = prev_decode_pos + max(prevH, prevW)
	maxPrev := prevH
	if prevW > maxPrev {
		maxPrev = prevW
	}
	targetGridDecodePos := prevGridDecodePos + maxPrev

	// Compute position IDs based on which grid the token is in
	if visualIdx < prevSize {
		// Token is in the prev grid (prev_token_h × prev_token_w)
		row := visualIdx / prevW
		col := visualIdx % prevW

		// temporal is CONSTANT for all tokens in this grid
		positions[0][0] = prevGridDecodePos
		// height and width are relative to grid's decode_pos
		positions[1][0] = prevGridDecodePos + row
		positions[2][0] = prevGridDecodePos + col
	} else {
		// Token is in the target grid (token_h × token_w)
		targetIdx := visualIdx - prevSize
		row := targetIdx / targetW
		col := targetIdx % targetW

		// temporal is CONSTANT for all tokens in this grid
		positions[0][0] = targetGridDecodePos
		// height and width are relative to grid's decode_pos
		positions[1][0] = targetGridDecodePos + row
		positions[2][0] = targetGridDecodePos + col
	}

	_ = targetH // Used for documentation clarity
	_ = absPos  // No longer used - kept for API compatibility
	return positions
}
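
// Worked example (illustrative numbers): with promptLen=32 and a 512x512
// target, tokenH=tokenW=16 and prevTokenH=prevTokenW=16, so prevSize=256.
// Then:
//   visualIdx=0   -> prev grid, row 0, col 0   -> positions [32, 32, 32]
//   visualIdx=17  -> prev grid, row 1, col 1   -> positions [32, 33, 33]
//   visualIdx=256 -> target grid starts at 32+max(16,16)=48, row 0, col 0
//                    -> positions [48, 48, 48]
//   visualIdx=273 -> target grid, row 1, col 1 -> positions [48, 49, 49]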

// sampleVisualToken samples from the visual vocabulary using top-p (nucleus) sampling
// Note: For GLM-Image, greedy decoding is not allowed as it may cause repetitive outputs
// Returns a visual token ID in range [0, 16511] which directly indexes into the embedding table
// sampleCounter is incremented for each call to ensure different random values
func sampleVisualToken(logits *mlx.Array, temperature float32, topP float32, cfg *VisionLanguageConfig, seed int64, sampleCounter *int64) int32 {
	// The LMHead outputs logits for visual tokens only (shape [1, 16512])
	// Output index directly corresponds to vocab ID [0, 16511]
	// No offset needed - the visual tokens are at vocab IDs [0, 16511]
	visualLogits := logits

	// Apply temperature
	if temperature != 1.0 && temperature > 0 {
		visualLogits = mlx.DivScalar(visualLogits, temperature)
	}

	// Apply softmax to get probabilities
	probs := mlx.Softmax(visualLogits, -1)
	mlx.Eval(probs)

	// Get the sampled index using top-p sampling
	// This directly gives us the vocab ID in [0, 16511]
	// Special tokens: 16384 = BOS, 16385 = EOS
	// Use seed + counter for reproducible but different random values
	effectiveSeed := seed + *sampleCounter
	*sampleCounter++
	return sampleTopP(probs, topP, effectiveSeed)
}

// sampleTopP implements nucleus (top-p) sampling
// probs: [1, vocab_size] probability distribution
// topP: cumulative probability threshold (e.g., 0.75)
// seed: random seed for reproducible sampling
func sampleTopP(probs *mlx.Array, topP float32, seed int64) int32 {
	// Negate probs for descending sort (Argsort only does ascending)
	negProbs := mlx.MulScalar(probs, -1)
	sortedIndices := mlx.Argsort(negProbs, -1)
	sortedProbs := mlx.TakeAlongAxis(probs, sortedIndices, -1)
	cumProbs := mlx.Cumsum(sortedProbs, -1)
	mlx.Eval(sortedIndices, sortedProbs, cumProbs)

	// Find cutoff index where cumulative probability exceeds topP
	probsData := sortedProbs.Data()
	cumProbsData := cumProbs.Data()
	indicesData := sortedIndices.DataInt32()

	// Calculate cutoff and renormalize
	var cutoffIdx int
	var totalProb float32
	for i, cp := range cumProbsData {
		totalProb += probsData[i]
		if cp >= topP {
			cutoffIdx = i + 1 // Include this token
			break
		}
	}
	if cutoffIdx == 0 {
		cutoffIdx = len(probsData) // Use all tokens if topP is very high
	}

	// Sample from the truncated distribution
	// Renormalize the truncated probabilities
	truncatedProbs := make([]float32, cutoffIdx)
	for i := 0; i < cutoffIdx; i++ {
		truncatedProbs[i] = probsData[i] / totalProb
	}

	// Sample using random number with provided seed for reproducibility
	r := mlx.RandomUniform([]int32{1}, uint64(seed))
	mlx.Eval(r)
	randVal := r.Data()[0]

	// Find the sampled token
	var cumulative float32
	for i := 0; i < cutoffIdx; i++ {
		cumulative += truncatedProbs[i]
		if randVal < cumulative {
			return indicesData[i]
		}
	}

	// Fallback to the last token in truncated set
	return indicesData[cutoffIdx-1]
}
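
// Standalone illustration, not part of the original file: the same nucleus
// cutoff-and-renormalize logic on a toy, already-sorted distribution, with no
// mlx arrays involved.
func exampleTopPCutoff() {
	probs := []float32{0.5, 0.3, 0.15, 0.05} // sorted descending, sums to 1
	topP := float32(0.75)
	var cum float32
	cutoff := len(probs)
	for i, p := range probs {
		cum += p
		if cum >= topP {
			cutoff = i + 1 // 0.5+0.3 = 0.8 >= 0.75, so keep the first two tokens
			break
		}
	}
	// cutoff == 2; sampling then renormalizes {0.5, 0.3} to {0.625, 0.375}
	fmt.Println(cutoff)
}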

// Forward for GLMBlock
func (b *GLMBlock) Forward(x *mlx.Array, seqLen int32, eps float32) *mlx.Array {
	return b.ForwardWithCache(x, seqLen, 0, eps, nil, nil)
}

// ForwardWithCache performs block forward with optional KV caching and MRoPE
// positionIDs: [3][L] - position indices for MRoPE (nil = use sequential positions)
func (b *GLMBlock) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, eps float32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
	// Pre-attention norm
	normed := b.InputLayerNorm.Forward(x, eps)

	// Self-attention with RoPE/MRoPE and cache
	attnOut := b.SelfAttn.ForwardWithCache(normed, seqLen, posOffset, kvcache, positionIDs)

	// Post-attention norm (GLM-4 style)
	attnOut = b.PostSelfAttnNorm.Forward(attnOut, eps)

	// Residual connection
	x = mlx.Add(x, attnOut)

	// Post-attention layer norm
	normed = b.PostAttnLayerNorm.Forward(x, eps)

	// MLP
	mlpOut := b.MLP.Forward(normed)

	// Post-MLP norm
	mlpOut = b.PostMLPLayerNorm.Forward(mlpOut, eps)

	// Residual connection
	x = mlx.Add(x, mlpOut)

	return x
}

// Forward for GLMAttention (without cache - used for prefill)
func (attn *GLMAttention) Forward(x *mlx.Array, seqLen int32) *mlx.Array {
	return attn.ForwardWithCache(x, seqLen, 0, nil, nil)
}

// ForwardWithCache performs attention with optional KV caching and MRoPE
// posOffset is the position offset for RoPE (0 for prefill, cached_len for decode)
// positionIDs: [3][L] - if nil, uses sequential positions for all dims (text mode)
// kvcache is updated in-place if provided
func (attn *GLMAttention) ForwardWithCache(x *mlx.Array, seqLen int32, posOffset int32, kvcache cache.Cache, positionIDs [][]int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	// Q, K, V projections
	q := mlx.Matmul(x, mlx.Transpose(attn.QProj, 1, 0))
	k := mlx.Matmul(x, mlx.Transpose(attn.KProj, 1, 0))
	v := mlx.Matmul(x, mlx.Transpose(attn.VProj, 1, 0))

	// Add biases
	if attn.QBias != nil {
		q = mlx.Add(q, attn.QBias)
	}
	if attn.KBias != nil {
		k = mlx.Add(k, attn.KBias)
	}
	if attn.VBias != nil {
		v = mlx.Add(v, attn.VBias)
	}

	// Reshape to [B, L, nheads, head_dim]
	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

	// Apply partial RoPE or MRoPE
	rotaryDim := int32(float32(attn.HeadDim) * attn.PartialRotary)
	if len(attn.MRoPESection) == 3 && positionIDs != nil {
		// Use MRoPE with explicit position IDs
		q = applyPartialMRoPE(q, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
		k = applyPartialMRoPE(k, positionIDs, rotaryDim, attn.RopeTheta, attn.MRoPESection)
	} else if len(attn.MRoPESection) == 3 {
		// Use MRoPE with sequential positions (same for all dims - text mode)
		seqPositions := make([][]int32, 3)
		for dim := 0; dim < 3; dim++ {
			seqPositions[dim] = make([]int32, L)
			for i := int32(0); i < L; i++ {
				seqPositions[dim][i] = i + posOffset
			}
		}
		q = applyPartialMRoPE(q, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
		k = applyPartialMRoPE(k, seqPositions, rotaryDim, attn.RopeTheta, attn.MRoPESection)
	} else {
		// Fallback to standard RoPE
		q = applyPartialRoPEWithOffset(q, L, posOffset, rotaryDim, attn.RopeTheta)
		k = applyPartialRoPEWithOffset(k, L, posOffset, rotaryDim, attn.RopeTheta)
	}

	// Transpose to [B, nheads, L, head_dim]
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	// Update cache and get full K, V for attention
	if kvcache != nil {
		k, v = kvcache.Update(k, v, int(L))
	}

	// Repeat KV for GQA
	kExpanded := k
	vExpanded := v
	if attn.NKVHeads < attn.NHeads {
		repeats := attn.NHeads / attn.NKVHeads
		kExpanded = repeatKV(k, repeats)
		vExpanded = repeatKV(v, repeats)
	}

	// Scaled dot-product attention with causal mask
	out := mlx.ScaledDotProductAttention(q, kExpanded, vExpanded, attn.Scale, true)

	// Transpose back [B, nheads, L, head_dim] -> [B, L, nheads, head_dim]
	out = mlx.Transpose(out, 0, 2, 1, 3)
	// Reshape to [B, L, hidden_size]
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

	// Output projection
	out = mlx.Matmul(out, mlx.Transpose(attn.OProj, 1, 0))

	return out
}

// applyPartialRoPE applies RoPE to only the first rotaryDim dimensions
func applyPartialRoPE(x *mlx.Array, seqLen int32, rotaryDim int32, theta float32) *mlx.Array {
	return applyPartialRoPEWithOffset(x, seqLen, 0, rotaryDim, theta)
}

// applyPartialRoPEWithOffset applies RoPE with a position offset
func applyPartialRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, rotaryDim int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]

	if rotaryDim <= 0 || rotaryDim > D {
		rotaryDim = D
	}

	// Split into rotary and pass-through parts
	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})

	// Apply RoPE to rotary part with position offset
	xRot = applyRoPEWithOffset(xRot, L, posOffset, theta)

	// Concatenate back
	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}

// applyPartialMRoPE applies Multi-dimensional RoPE (MRoPE) to the first rotaryDim dimensions
// positionIDs: [3, L] - position indices for each dimension (temporal, height, width)
// mrope_section: [8, 12, 12] - frequency pairs per dimension
// For text tokens: all 3 dimensions have the same sequential position
// For image tokens: temporal=seq_idx, height=row, width=col
func applyPartialMRoPE(x *mlx.Array, positionIDs [][]int32, rotaryDim int32, theta float32, mropeSection []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]

	if rotaryDim <= 0 || rotaryDim > D {
		rotaryDim = D
	}

	// Split into rotary and pass-through parts
	xRot := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, rotaryDim})
	xPass := mlx.Slice(x, []int32{0, 0, 0, rotaryDim}, []int32{B, L, H, D})

	// Apply MRoPE to rotary part
	xRot = applyMRoPE(xRot, positionIDs, theta, mropeSection)

	// Concatenate back
	return mlx.Concatenate([]*mlx.Array{xRot, xPass}, 3)
}

// applyMRoPE applies multi-dimensional rotary position embedding
// x: [B, L, H, D] where D is the rotary dimension
// positionIDs: [3][L] - positions for temporal, height, width dimensions
// mropeSection: [8, 12, 12] - frequency pairs per dimension
func applyMRoPE(x *mlx.Array, positionIDs [][]int32, theta float32, mropeSection []int32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	// Validate mrope_section sums to half (number of frequency pairs)
	var totalPairs int32
	for _, s := range mropeSection {
		totalPairs += s
	}
	if totalPairs != half {
		// Fallback to standard RoPE if section doesn't match
		return applyRoPEWithOffset(x, L, 0, theta)
	}

	// Build angles for each position dimension (matching Python's MRoPE approach)
	// Python: compute freqs for all dims, then apply_mrope selects freq ranges, then duplicate
	// Order: [temporal_8, height_12, width_12] -> duplicate -> [t8, h12, w12, t8, h12, w12]
	angleVals := make([]*mlx.Array, 3)

	freqOffset := int32(0)
	for dim := 0; dim < 3; dim++ {
		numPairs := mropeSection[dim]
		if numPairs == 0 {
			continue
		}

		// Compute inverse frequencies for this section
		// Each dimension uses DIFFERENT frequency ranges:
		// - Temporal: frequencies 0 to section[0]-1
		// - Height: frequencies section[0] to section[0]+section[1]-1
		// - Width: frequencies section[0]+section[1] to sum(section)-1
		freqsArr := make([]float32, numPairs)
		for i := int32(0); i < numPairs; i++ {
			globalIdx := freqOffset + i
			freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*globalIdx)/float64(D)))
		}
		freqs := mlx.NewArray(freqsArr, []int32{numPairs})

		// Position indices for this dimension
		posArr := make([]float32, L)
		for i := int32(0); i < L; i++ {
			posArr[i] = float32(positionIDs[dim][i])
		}
		pos := mlx.NewArray(posArr, []int32{L})

		// Compute angles: [L, numPairs] = outer(pos, freqs)
		posExpanded := mlx.Reshape(pos, L, 1)
		freqsExpanded := mlx.Reshape(freqs, 1, numPairs)
		angleVals[dim] = mlx.Mul(posExpanded, freqsExpanded)

		freqOffset += numPairs
	}

	// Concatenate all sections: [L, half] = [L, 32]
	allAngles := mlx.Concatenate(angleVals, 1)

	// Duplicate AFTER concatenation: [L, D] = [L, 64]
	// This gives: [temporal_8, height_12, width_12, temporal_8, height_12, width_12]
	allAngles = mlx.Concatenate([]*mlx.Array{allAngles, allAngles}, 1)

	// Compute cos/sin
	allCos := mlx.Cos(allAngles)
	allSin := mlx.Sin(allAngles)

	// Reshape for broadcasting: [1, L, 1, D] to match x [B, L, H, D]
	allCos = mlx.Reshape(allCos, 1, L, 1, D)
	allSin = mlx.Reshape(allSin, 1, L, 1, D)

	// x_rotated = cat([-x_imag, x_real], dim=-1)
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)        // [-x_imag, x_real]

	// out = x * cos + x_rotated * sin
	return mlx.Add(mlx.Mul(x, allCos), mlx.Mul(xRotated, allSin))
}
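
// Frequency layout, worked through (values from the comments above, given as
// an illustration): with rotaryDim D=64, half = 32 = 8+12+12, so pairs 0-7
// rotate with the temporal position, pairs 8-19 with the row, and pairs 20-31
// with the column; pair i uses frequency theta^(-2i/D). After the
// duplicate-concat, the per-channel angle order is
// [t×8, h×12, w×12, t×8, h×12, w×12].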

// applyRoPE applies rotary position embedding
func applyRoPE(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
	return applyRoPEWithOffset(x, seqLen, 0, theta)
}

// applyRoPEWithOffset applies rotary position embedding with position offset
// Uses the split-half approach (matches diffusers GLM-Image with use_real_unbind_dim=-2)
func applyRoPEWithOffset(x *mlx.Array, seqLen int32, posOffset int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	// Compute inverse frequencies: 1 / (theta^(2i/d))
	freqsArr := make([]float32, half)
	for i := int32(0); i < half; i++ {
		freqsArr[i] = float32(1.0 / math.Pow(float64(theta), float64(2*i)/float64(D)))
	}
	freqs := mlx.NewArray(freqsArr, []int32{half})

	// Position indices with offset
	posArr := make([]float32, L)
	for i := int32(0); i < L; i++ {
		posArr[i] = float32(i + posOffset)
	}
	pos := mlx.NewArray(posArr, []int32{L})

	// Compute angles: [L, half] = outer(pos, freqs)
	posExpanded := mlx.Reshape(pos, L, 1)
	freqsExpanded := mlx.Reshape(freqs, 1, half)
	angles := mlx.Mul(posExpanded, freqsExpanded)

	// Duplicate angles to match diffusers: cat([angles, angles], dim=-1) -> [L, D]
	anglesDup := mlx.Concatenate([]*mlx.Array{angles, angles}, 1)

	// Cos and sin: [L, 1, D] for broadcasting to [B, L, H, D]
	cosVals := mlx.Cos(anglesDup)
	sinVals := mlx.Sin(anglesDup)
	cosVals = mlx.Reshape(cosVals, L, 1, D)
	sinVals = mlx.Reshape(sinVals, L, 1, D)

	// x_rotated = cat([-x_imag, x_real], dim=-1) where x_real=x[..., :half], x_imag=x[..., half:]
	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half}) // x_real
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D}) // x_imag
	x2Neg := mlx.MulScalar(x2, -1)                                  // -x_imag
	xRotated := mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)        // [-x_imag, x_real]

	// out = x * cos + x_rotated * sin
	return mlx.Add(mlx.Mul(x, cosVals), mlx.Mul(xRotated, sinVals))
}
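
// Sanity check (illustrative): for a single pair, D=2 and half=1, the
// split-half form out = x*cos + [-x_imag, x_real]*sin expands to
//
//	out_real = x_real*cos(p*f) - x_imag*sin(p*f)
//	out_imag = x_imag*cos(p*f) + x_real*sin(p*f)
//
// i.e. a plain 2-D rotation of (x_real, x_imag) by the angle p*f for
// position p and inverse frequency f.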

// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
	if repeats == 1 {
		return x
	}
	shape := x.Shape()
	// x: [B, nkvheads, L, head_dim]
	x = mlx.ExpandDims(x, 2)
	// x: [B, nkvheads, 1, L, head_dim]
	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
	// x: [B, nkvheads, repeats, L, head_dim]
	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
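
// Shape example (illustrative sizes, not the model's actual head counts):
// with NKVHeads=2, NHeads=32, L=128 and HeadDim=128, repeats=16 and the
// expand/tile/reshape sequence takes
// [1, 2, 128, 128] -> [1, 2, 1, 128, 128] -> [1, 2, 16, 128, 128] -> [1, 32, 128, 128].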

// Forward for GLMMLP (fused gate_up SwiGLU)
func (m *GLMMLP) Forward(x *mlx.Array) *mlx.Array {
	// gate_up_proj outputs [gate, up] concatenated
	gateUp := mlx.Matmul(x, mlx.Transpose(m.GateUpProj, 1, 0))

	shape := gateUp.Shape()
	halfDim := shape[len(shape)-1] / 2

	// Split into gate and up
	gate := mlx.Slice(gateUp, []int32{0, 0, 0}, []int32{shape[0], shape[1], halfDim})
	up := mlx.Slice(gateUp, []int32{0, 0, halfDim}, []int32{shape[0], shape[1], shape[2]})

	// SwiGLU: silu(gate) * up
	gate = mlx.SiLU(gate)
	h := mlx.Mul(gate, up)

	// Down projection
	return mlx.Matmul(h, mlx.Transpose(m.DownProj, 1, 0))
}
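
// Standalone illustration, not part of the original file: SwiGLU on scalars,
// matching the fused split above, where silu(g) = g*sigmoid(g).
func exampleSwiGLU(g, u float32) float32 {
	silu := g / (1 + float32(math.Exp(float64(-g))))
	return silu * u
}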

@@ -19,9 +19,15 @@ import (
 
 	"github.com/ollama/ollama/x/imagegen"
 	"github.com/ollama/ollama/x/imagegen/mlx"
+	"github.com/ollama/ollama/x/imagegen/models/glm_image"
 	"github.com/ollama/ollama/x/imagegen/models/zimage"
 )
 
+// ImageModel is the interface for image generation models
+type ImageModel interface {
+	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error)
+}
+
 // Request is the image generation request format
 type Request struct {
 	Prompt string `json:"prompt"`
@@ -41,8 +47,9 @@ type Response struct {
 // Server holds the model and handles requests
 type Server struct {
 	mu        sync.Mutex
-	model     *zimage.Model
+	model     ImageModel
 	modelName string
+	modelType string // pipeline class name, e.g. "GlmImagePipeline" or "ZImagePipeline"
 }
 
 // Execute is the entry point for the image runner subprocess
@@ -72,15 +79,35 @@ func Execute(args []string) error {
 			requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
 	}
 
-	// Load model
-	model := &zimage.Model{}
-	if err := model.Load(*modelName); err != nil {
-		return fmt.Errorf("failed to load model: %w", err)
+	// Detect model type and load appropriate model
+	modelType, err := detectModelType(*modelName)
+	if err != nil {
+		return fmt.Errorf("failed to detect model type: %w", err)
+	}
+
+	var model ImageModel
+	switch modelType {
+	case "GlmImagePipeline":
+		slog.Info("loading GLM-Image model")
+		m := &glm_image.Model{}
+		if err := m.Load(*modelName); err != nil {
+			return fmt.Errorf("failed to load GLM-Image model: %w", err)
+		}
+		model = m
+	default:
+		// Default to zimage for ZImagePipeline, FluxPipeline, and unknown types
+		slog.Info("loading Z-Image model")
+		m := &zimage.Model{}
+		if err := m.Load(*modelName); err != nil {
+			return fmt.Errorf("failed to load Z-Image model: %w", err)
+		}
+		model = m
 	}
 
 	server := &Server{
 		model:     model,
 		modelName: *modelName,
+		modelType: modelType,
 	}
 
 	// Set up HTTP handlers
@@ -144,7 +171,13 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		req.Height = 1024
 	}
 	if req.Steps <= 0 {
-		req.Steps = 9
+		// Default steps depend on model type
+		switch s.modelType {
+		case "GlmImagePipeline":
+			req.Steps = 50 // GLM-Image default
+		default:
+			req.Steps = 9 // Z-Image turbo default
+		}
 	}
 	if req.Seed <= 0 {
 		req.Seed = time.Now().UnixNano()
@@ -159,25 +192,9 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Generate image
+	// Generate image using interface method
 	ctx := r.Context()
-	img, err := s.model.GenerateFromConfig(ctx, &zimage.GenerateConfig{
-		Prompt: req.Prompt,
-		Width:  req.Width,
-		Height: req.Height,
-		Steps:  req.Steps,
-		Seed:   req.Seed,
-		Progress: func(step, total int) {
-			resp := Response{
-				Content: fmt.Sprintf("\rGenerating: step %d/%d", step, total),
-				Done:    false,
-			}
-			data, _ := json.Marshal(resp)
-			w.Write(data)
-			w.Write([]byte("\n"))
-			flusher.Flush()
-		},
-	})
+	img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed)
 	if err != nil {
 		// Don't send error for cancellation
@@ -216,3 +233,35 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte("\n"))
 	flusher.Flush()
 }
+
+// detectModelType reads the model manifest and returns the pipeline class name
+func detectModelType(modelName string) (string, error) {
+	manifest, err := imagegen.LoadManifest(modelName)
+	if err != nil {
+		return "", err
+	}
+
+	data, err := manifest.ReadConfig("model_index.json")
+	if err != nil {
+		return "ZImagePipeline", nil // Default to Z-Image
+	}
+
+	// Try both _class_name (diffusers format) and architecture (ollama format)
+	var index struct {
+		ClassName    string `json:"_class_name"`
+		Architecture string `json:"architecture"`
+	}
+	if err := json.Unmarshal(data, &index); err != nil {
+		return "ZImagePipeline", nil
+	}
+
+	// Prefer _class_name, fall back to architecture
+	className := index.ClassName
+	if className == "" {
+		className = index.Architecture
+	}
+	if className == "" {
+		return "ZImagePipeline", nil
+	}
+	return className, nil
+}