Compare commits

..

5 Commits

Author SHA1 Message Date
Patrick Devine
f66d4bc75e x models: add glm 4.7 flash to mlx engine 2026-01-21 19:07:02 -08:00
Patrick Devine
369dfe07ba change function names 2026-01-21 11:17:43 -08:00
Patrick Devine
5818394002 gofumpt the gofmt 2026-01-20 17:35:23 -08:00
Patrick Devine
51631bcba0 replace modelpath.go w/ types/model/name.go 2026-01-20 17:14:54 -08:00
Patrick Devine
6e43944c94 housekeeping: move manifest implementation 2026-01-20 16:19:21 -08:00
12 changed files with 557 additions and 538 deletions

View File

@@ -609,49 +609,3 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
c.Next()
}
}
func ImageEditsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ImageEditRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Prompt == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
return
}
if req.Image == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
return
}
genReq, err := openai.FromImageEditRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ImageWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}

View File

@@ -1112,129 +1112,3 @@ func TestImageWriterResponse(t *testing.T) {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}
func TestImageEditsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
// Base64-encoded test image (1x1 pixel PNG)
testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
testCases := []testCase{
{
name: "image edit basic",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
},
},
{
name: "image edit with size",
body: `{
"model": "test-model",
"prompt": "make it blue",
"image": "` + testImage + `",
"size": "512x768"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "make it blue",
Images: []api.ImageData{decodedImage},
Width: 512,
Height: 768,
},
},
{
name: "image edit missing prompt",
body: `{
"model": "test-model",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "prompt is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing model",
body: `{
"prompt": "make it blue",
"image": "` + testImage + `"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "model is required",
Type: "invalid_request_error",
},
},
},
{
name: "image edit missing image",
body: `{
"model": "test-model",
"prompt": "make it blue"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "image is required",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if tc.err.Error.Message != "" {
var errResp openai.ErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match:\n%s", diff)
}
return
}
if resp.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match:\n%s", diff)
}
})
}
}

View File

@@ -794,47 +794,3 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
Data: data,
}
}
// ImageEditRequest is an OpenAI-compatible image edit request.
type ImageEditRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image string `json:"image"` // Base64-encoded image data
Size string `json:"size,omitempty"` // e.g., "1024x1024"
Seed *int64 `json:"seed,omitempty"`
}
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
req := api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
}
// Decode the input image
if r.Image != "" {
imgData, err := decodeImageURL(r.Image)
if err != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
}
req.Images = append(req.Images, imgData)
}
// Parse size if provided (e.g., "1024x768")
if r.Size != "" {
var w, h int32
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
req.Width = w
req.Height = h
}
}
if r.Seed != nil {
if req.Options == nil {
req.Options = map[string]any{}
}
req.Options["seed"] = *r.Seed
}
return req, nil
}

View File

@@ -448,86 +448,3 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
func TestFromImageEditRequest_Basic(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", result.Model)
}
if result.Prompt != "make it blue" {
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
}
if len(result.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(result.Images))
}
}
func TestFromImageEditRequest_WithSize(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Size: "512x768",
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Width != 512 {
t.Errorf("expected width 512, got %d", result.Width)
}
if result.Height != 768 {
t.Errorf("expected height 768, got %d", result.Height)
}
}
func TestFromImageEditRequest_WithSeed(t *testing.T) {
seed := int64(12345)
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: prefix + image,
Seed: &seed,
}
result, err := FromImageEditRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Options == nil {
t.Fatal("expected options to be set")
}
if result.Options["seed"] != seed {
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
}
}
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
req := ImageEditRequest{
Model: "test-model",
Prompt: "make it blue",
Image: "not-valid-base64",
}
_, err := FromImageEditRequest(req)
if err == nil {
t.Error("expected error for invalid image")
}
}

View File

@@ -1604,9 +1604,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
// OpenAI-compatible image generation endpoints
// OpenAI-compatible image generation endpoint
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -2524,11 +2523,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
}
var images []llm.ImageData
for i, imgData := range req.Images {
images = append(images, llm.ImageData{ID: i, Data: imgData})
}
var streamStarted bool
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
@@ -2536,7 +2530,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{

View File

@@ -2193,157 +2193,3 @@ func TestGenerateUnload(t *testing.T) {
}
})
}
func TestGenerateWithImages(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Files: map[string]string{"file.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("images passed to completion request", func(t *testing.T) {
testImage := []byte("test-image-data")
mock.CompletionResponse.Content = "Image processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Describe this image",
Images: []api.ImageData{testImage},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify images were passed to the completion request
if len(mock.CompletionRequest.Images) != 1 {
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
t.Errorf("image data mismatch in completion request")
}
if mock.CompletionRequest.Images[0].ID != 0 {
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
}
})
t.Run("multiple images passed to completion request", func(t *testing.T) {
testImage1 := []byte("test-image-1")
testImage2 := []byte("test-image-2")
mock.CompletionResponse.Content = "Images processed"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Compare these images",
Images: []api.ImageData{testImage1, testImage2},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify both images were passed
if len(mock.CompletionRequest.Images) != 2 {
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
t.Errorf("first image data mismatch")
}
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
t.Errorf("second image data mismatch")
}
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
t.Errorf("expected image IDs 0 and 1, got %d and %d",
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
}
})
t.Run("no images when none provided", func(t *testing.T) {
mock.CompletionResponse.Content = "No images"
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
// Verify no images in completion request
if len(mock.CompletionRequest.Images) != 0 {
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
}
})
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/flux2"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
@@ -301,6 +302,8 @@ func load(modelPath string) (Model, error) {
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
case "glm4_moe_lite":
return glm4_moe_lite.Load(modelPath)
default:
return llama.Load(modelPath)
}

View File

@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
// Stack stacks arrays along a new axis (axis 0 by default)
func Stack(arrays []*Array, axis int) *Array {
handles := make([]C.mlx_array, len(arrays))
for i, arr := range arrays {
handles[i] = arr.c
}
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
res := C.mlx_array_new()
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
C.mlx_vector_array_free(vec)
return newArray(res)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)

View File

@@ -177,20 +177,6 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
})
}
// GenerateImageWithInputs implements runner.ImageEditModel interface.
// It generates an image conditioned on the provided input images for image editing.
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
InputImages: inputImages,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048

View File

@@ -0,0 +1,529 @@
//go:build mlx
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim)
}
// MLAAttention implements Multi-head Latent Attention
type MLAAttention struct {
// Low-rank query projections
QAProj *nn.Linear `weight:"self_attn.q_a_proj"`
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
QBProj *nn.Linear `weight:"self_attn.q_b_proj"`
// Low-rank KV projections (with shared rope component)
KVAProjWithMQA *nn.Linear `weight:"self_attn.kv_a_proj_with_mqa"`
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
KVBProj *nn.Linear `weight:"self_attn.kv_b_proj"`
// Output projection
OProj *nn.Linear `weight:"self_attn.o_proj"`
}
// Forward computes MLA attention output
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Query path: q_a_proj -> layernorm -> q_b_proj
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
// Split Q into nope and rope parts
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
// KV path: kv_a_proj_with_mqa -> split -> layernorm -> kv_b_proj
compressedKV := a.KVAProjWithMQA.Forward(x)
// Split into compressed_kv and k_pe (shared rope component)
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
// Apply layernorm and project KV
kvCompressed = a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kv := a.KVBProj.Forward(kvCompressed)
// Reshape KV: [B, L, num_heads * (qk_nope_head_dim + v_head_dim)]
kv = mlx.Reshape(kv, B, L, cfg.NumAttentionHeads, cfg.QKNopeHeadDim+cfg.VHeadDim)
kv = mlx.Transpose(kv, 0, 2, 1, 3)
// Split into k_nope and values
kNope := mlx.Slice(kv, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
values := mlx.Slice(kv, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim + cfg.VHeadDim})
// Apply RoPE to the rope parts only
offset := 0
if c != nil {
offset = c.Offset()
}
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
// Repeat k_pe across all heads
kPE = mlx.Tile(kPE, []int32{1, cfg.NumAttentionHeads, 1, 1})
// Concatenate nope and rope parts
queries := mlx.Concatenate([]*mlx.Array{qNope, qPE}, 3)
keys := mlx.Concatenate([]*mlx.Array{kNope, kPE}, 3)
// Update KV cache
if c != nil {
keys, values = c.Update(keys, values, int(L))
}
// Scaled dot product attention
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Weight *mlx.Array `weight:"mlp.gate.weight"`
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
// Compute gate logits: x @ weight.T
gates := mlx.Linear(x, mlx.Transpose(g.Weight, 1, 0))
// Sigmoid scoring
scores := mlx.Sigmoid(gates)
origScores := scores
// Add correction bias if present
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
}
// Group-wise expert selection (simplified for n_group=1)
// Select top-k experts
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
shape := inds.Shape()
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
// Get scores for selected experts
scores = mlx.TakeAlongAxis(origScores, inds, -1)
// Normalize if configured
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
// Apply routing scaling factor
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
// Note: No weight tags - these are populated manually by stacking expert weights
type SwitchMLP struct {
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
topK := cfg.NumExpertsPerTok
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
// Flatten for gather_mm: [B*L, 1, 1, D]
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
// Flatten indices: [B, L, topK] -> [B*L, topK]
idxFlat := mlx.Reshape(indices, B*L, topK)
// Sort for efficient gather (when we have many tokens)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
// Reorder x based on sorted indices
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
// Expert computation using gather_mm
// gate: x @ gate_weight.T (indices are on the rhs/weight side)
gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
// up: x @ up_weight.T
up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
// SwiGLU activation
hidden := mlx.Mul(mlx.SiLU(gate), up)
// down: hidden @ down_weight.T
down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
// Unsort if we sorted
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj *nn.Linear `weight:"mlp.shared_experts.gate_proj"`
UpProj *nn.Linear `weight:"mlp.shared_experts.up_proj"`
DownProj *nn.Linear `weight:"mlp.shared_experts.down_proj"`
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
shape := x.Shape()
B, L := shape[0], shape[1]
// Get expert indices and scores
inds, scores := m.Gate.Forward(x, cfg)
// Apply routed experts
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
// Add shared experts if present
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MLP with residual
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// Pre-norm attention with residual
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
h := mlx.Add(x, r)
// Pre-norm MoE with residual
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []Block `weight:"-"` // Loaded manually due to different block types
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead *nn.Linear `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
// sanitizeExpertWeights stacks individual expert weights into a single tensor
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
var gateWeights, upWeights, downWeights []*mlx.Array
for e := int32(0); e < numExperts; e++ {
gw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.gate_proj.weight", prefix, e))
uw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.up_proj.weight", prefix, e))
dw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.down_proj.weight", prefix, e))
if gw != nil {
gateWeights = append(gateWeights, gw)
}
if uw != nil {
upWeights = append(upWeights, uw)
}
if dw != nil {
downWeights = append(downWeights, dw)
}
}
var stackedGate, stackedUp, stackedDown *mlx.Array
if len(gateWeights) > 0 {
stackedGate = mlx.Stack(gateWeights, 0)
}
if len(upWeights) > 0 {
stackedUp = mlx.Stack(upWeights, 0)
}
if len(downWeights) > 0 {
stackedDown = mlx.Stack(downWeights, 0)
}
return stackedGate, stackedUp, stackedDown
}
// Load loads a GLM4-MoE-Lite model from the given path
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute derived fields
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load embedding, norm, and lm_head
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers manually due to different block types
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d attention: %w", i, err)
}
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d dense: %w", i, err)
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
}
// Stack expert weights
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
block.MoE = &MoE{
Gate: &MoEGate{},
SwitchMLP: &SwitchMLP{
GateWeight: gateW,
UpWeight: upW,
DownWeight: downW,
},
}
// Load gate weights
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d gate: %w", i, err)
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{}
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
}
}
m.Layers[i] = block
}
}
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return m.LMHead.Forward(h)
}
// Interface methods
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>\n" + prompt + "<|assistant|>\n"
}

View File

@@ -9,7 +9,6 @@ import (
"encoding/json"
"flag"
"fmt"
"image"
"log/slog"
"net/http"
"os"
@@ -26,12 +25,11 @@ import (
// Request is the image generation request format
type Request struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
// Response is streamed back for each progress update
@@ -48,13 +46,6 @@ type ImageModel interface {
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}
// ImageEditModel extends ImageModel with image editing/conditioning capability.
// Models that support input images for editing should implement this interface.
type ImageEditModel interface {
ImageModel
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
}
// Server holds the model and handles requests
type Server struct {
mu sync.Mutex
@@ -170,29 +161,6 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Validate and decode input images
const maxInputImages = 2
if len(req.Images) > maxInputImages {
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
return
}
var inputImages []image.Image
if len(req.Images) > 0 {
// TODO: add memory check for input images
inputImages = make([]image.Image, len(req.Images))
for i, imgBytes := range req.Images {
img, err := imagegen.DecodeImage(imgBytes)
if err != nil {
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
return
}
inputImages[i] = img
}
slog.Info("decoded input images", "count", len(inputImages))
}
// Serialize generation requests - MLX model may not handle concurrent generation
s.mu.Lock()
defer s.mu.Unlock()
@@ -224,19 +192,7 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
}
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
var img *mlx.Array
var err error
if len(inputImages) > 0 {
editModel, ok := s.model.(ImageEditModel)
if !ok {
http.Error(w, "model does not support image editing", http.StatusBadRequest)
return
}
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
} else {
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
}
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
if err != nil {
// Don't send error for cancellation

View File

@@ -226,27 +226,19 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
seed = time.Now().UnixNano()
}
// Extract raw image bytes from llm.ImageData slice
var images [][]byte
for _, img := range req.Images {
images = append(images, img.Data)
}
// Build request for subprocess
creq := struct {
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
Images [][]byte `json:"images,omitempty"`
Prompt string `json:"prompt"`
Width int32 `json:"width,omitempty"`
Height int32 `json:"height,omitempty"`
Steps int32 `json:"steps,omitempty"`
Seed int64 `json:"seed,omitempty"`
}{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}
body, err := json.Marshal(creq)