mirror of
https://github.com/ollama/ollama.git
synced 2026-01-21 22:10:58 -05:00
Compare commits
5 Commits
focused-mo
...
pdevine/gl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f66d4bc75e | ||
|
|
369dfe07ba | ||
|
|
5818394002 | ||
|
|
51631bcba0 | ||
|
|
6e43944c94 |
@@ -609,49 +609,3 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ImageEditsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ImageEditRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Image == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
|
||||
return
|
||||
}
|
||||
|
||||
genReq, err := openai.FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ImageWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,129 +1112,3 @@ func TestImageWriterResponse(t *testing.T) {
|
||||
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageEditsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
// Base64-encoded test image (1x1 pixel PNG)
|
||||
testImage := ""
|
||||
decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "image edit basic",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit with size",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `",
|
||||
"size": "512x768"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Images: []api.ImageData{decodedImage},
|
||||
Width: 512,
|
||||
Height: 768,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing prompt",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "prompt is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing model",
|
||||
body: `{
|
||||
"prompt": "make it blue",
|
||||
"image": "` + testImage + `"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "model is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "image edit missing image",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "make it blue"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "image is required",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if tc.err.Error.Message != "" {
|
||||
var errResp openai.ErrorResponse
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match:\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -794,47 +794,3 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ImageEditRequest is an OpenAI-compatible image edit request.
|
||||
type ImageEditRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image string `json:"image"` // Base64-encoded image data
|
||||
Size string `json:"size,omitempty"` // e.g., "1024x1024"
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
|
||||
func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
|
||||
req := api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
}
|
||||
|
||||
// Decode the input image
|
||||
if r.Image != "" {
|
||||
imgData, err := decodeImageURL(r.Image)
|
||||
if err != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
|
||||
}
|
||||
req.Images = append(req.Images, imgData)
|
||||
}
|
||||
|
||||
// Parse size if provided (e.g., "1024x768")
|
||||
if r.Size != "" {
|
||||
var w, h int32
|
||||
if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
|
||||
req.Width = w
|
||||
req.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
if r.Seed != nil {
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
req.Options["seed"] = *r.Seed
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -448,86 +448,3 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_Basic(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", result.Model)
|
||||
}
|
||||
|
||||
if result.Prompt != "make it blue" {
|
||||
t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
|
||||
}
|
||||
|
||||
if len(result.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(result.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSize(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Size: "512x768",
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Width != 512 {
|
||||
t.Errorf("expected width 512, got %d", result.Width)
|
||||
}
|
||||
|
||||
if result.Height != 768 {
|
||||
t.Errorf("expected height 768, got %d", result.Height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_WithSeed(t *testing.T) {
|
||||
seed := int64(12345)
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: prefix + image,
|
||||
Seed: &seed,
|
||||
}
|
||||
|
||||
result, err := FromImageEditRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Options == nil {
|
||||
t.Fatal("expected options to be set")
|
||||
}
|
||||
|
||||
if result.Options["seed"] != seed {
|
||||
t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromImageEditRequest_InvalidImage(t *testing.T) {
|
||||
req := ImageEditRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "make it blue",
|
||||
Image: "not-valid-base64",
|
||||
}
|
||||
|
||||
_, err := FromImageEditRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid image")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1604,9 +1604,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
// OpenAI-compatible image generation endpoint
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
@@ -2524,11 +2523,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
}
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i, imgData := range req.Images {
|
||||
images = append(images, llm.ImageData{ID: i, Data: imgData})
|
||||
}
|
||||
|
||||
var streamStarted bool
|
||||
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
@@ -2536,7 +2530,6 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
streamStarted = true
|
||||
res := api.GenerateResponse{
|
||||
|
||||
@@ -2193,157 +2193,3 @@ func TestGenerateUnload(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateWithImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("images passed to completion request", func(t *testing.T) {
|
||||
testImage := []byte("test-image-data")
|
||||
|
||||
mock.CompletionResponse.Content = "Image processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{testImage},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify images were passed to the completion request
|
||||
if len(mock.CompletionRequest.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
|
||||
t.Errorf("image data mismatch in completion request")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 {
|
||||
t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple images passed to completion request", func(t *testing.T) {
|
||||
testImage1 := []byte("test-image-1")
|
||||
testImage2 := []byte("test-image-2")
|
||||
|
||||
mock.CompletionResponse.Content = "Images processed"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Compare these images",
|
||||
Images: []api.ImageData{testImage1, testImage2},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify both images were passed
|
||||
if len(mock.CompletionRequest.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
|
||||
t.Errorf("first image data mismatch")
|
||||
}
|
||||
|
||||
if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
|
||||
t.Errorf("second image data mismatch")
|
||||
}
|
||||
|
||||
if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
|
||||
t.Errorf("expected image IDs 0 and 1, got %d and %d",
|
||||
mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no images when none provided", func(t *testing.T) {
|
||||
mock.CompletionResponse.Content = "No images"
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify no images in completion request
|
||||
if len(mock.CompletionRequest.Images) != 0 {
|
||||
t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gemma3"
|
||||
"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
|
||||
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
|
||||
"github.com/ollama/ollama/x/imagegen/models/llama"
|
||||
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
|
||||
@@ -301,6 +302,8 @@ func load(modelPath string) (Model, error) {
|
||||
return gemma3.Load(modelPath)
|
||||
case "gemma3_text":
|
||||
return gemma3.LoadText(modelPath)
|
||||
case "glm4_moe_lite":
|
||||
return glm4_moe_lite.Load(modelPath)
|
||||
default:
|
||||
return llama.Load(modelPath)
|
||||
}
|
||||
|
||||
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
|
||||
return Concatenate([]*Array{a, b}, axis)
|
||||
}
|
||||
|
||||
// Stack stacks arrays along a new axis (axis 0 by default)
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
handles := make([]C.mlx_array, len(arrays))
|
||||
for i, arr := range arrays {
|
||||
handles[i] = arr.c
|
||||
}
|
||||
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
|
||||
C.mlx_vector_array_free(vec)
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
// Slice slices the array
|
||||
func Slice(a *Array, start, stop []int32) *Array {
|
||||
n := len(start)
|
||||
|
||||
@@ -177,20 +177,6 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateImageWithInputs implements runner.ImageEditModel interface.
|
||||
// It generates an image conditioned on the provided input images for image editing.
|
||||
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
|
||||
return m.GenerateFromConfig(ctx, &GenerateConfig{
|
||||
Prompt: prompt,
|
||||
Width: width,
|
||||
Height: height,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
InputImages: inputImages,
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
|
||||
const MaxOutputPixels = 2048 * 2048
|
||||
|
||||
|
||||
529
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
529
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
Normal file
@@ -0,0 +1,529 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
NSharedExperts int32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
||||
NGroup int32 `json:"n_group"`
|
||||
TopKGroup int32 `json:"topk_group"`
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim)
|
||||
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention
|
||||
type MLAAttention struct {
|
||||
// Low-rank query projections
|
||||
QAProj *nn.Linear `weight:"self_attn.q_a_proj"`
|
||||
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
|
||||
QBProj *nn.Linear `weight:"self_attn.q_b_proj"`
|
||||
|
||||
// Low-rank KV projections (with shared rope component)
|
||||
KVAProjWithMQA *nn.Linear `weight:"self_attn.kv_a_proj_with_mqa"`
|
||||
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
|
||||
KVBProj *nn.Linear `weight:"self_attn.kv_b_proj"`
|
||||
|
||||
// Output projection
|
||||
OProj *nn.Linear `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
// Forward computes MLA attention output
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Query path: q_a_proj -> layernorm -> q_b_proj
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
||||
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
// Split Q into nope and rope parts
|
||||
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
||||
|
||||
// KV path: kv_a_proj_with_mqa -> split -> layernorm -> kv_b_proj
|
||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
||||
|
||||
// Split into compressed_kv and k_pe (shared rope component)
|
||||
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
||||
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
||||
|
||||
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
|
||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
||||
|
||||
// Apply layernorm and project KV
|
||||
kvCompressed = a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
kv := a.KVBProj.Forward(kvCompressed)
|
||||
|
||||
// Reshape KV: [B, L, num_heads * (qk_nope_head_dim + v_head_dim)]
|
||||
kv = mlx.Reshape(kv, B, L, cfg.NumAttentionHeads, cfg.QKNopeHeadDim+cfg.VHeadDim)
|
||||
kv = mlx.Transpose(kv, 0, 2, 1, 3)
|
||||
|
||||
// Split into k_nope and values
|
||||
kNope := mlx.Slice(kv, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
values := mlx.Slice(kv, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim + cfg.VHeadDim})
|
||||
|
||||
// Apply RoPE to the rope parts only
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
// Repeat k_pe across all heads
|
||||
kPE = mlx.Tile(kPE, []int32{1, cfg.NumAttentionHeads, 1, 1})
|
||||
|
||||
// Concatenate nope and rope parts
|
||||
queries := mlx.Concatenate([]*mlx.Array{qNope, qPE}, 3)
|
||||
keys := mlx.Concatenate([]*mlx.Array{kNope, kPE}, 3)
|
||||
|
||||
// Update KV cache
|
||||
if c != nil {
|
||||
keys, values = c.Update(keys, values, int(L))
|
||||
}
|
||||
|
||||
// Scaled dot product attention
|
||||
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
||||
type DenseMLP struct {
|
||||
GateProj *nn.Linear `weight:"mlp.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Weight *mlx.Array `weight:"mlp.gate.weight"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
// Compute gate logits: x @ weight.T
|
||||
gates := mlx.Linear(x, mlx.Transpose(g.Weight, 1, 0))
|
||||
|
||||
// Sigmoid scoring
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
// Add correction bias if present
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
}
|
||||
|
||||
// Group-wise expert selection (simplified for n_group=1)
|
||||
// Select top-k experts
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
shape := inds.Shape()
|
||||
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
|
||||
|
||||
// Get scores for selected experts
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
// Normalize if configured
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
}
|
||||
|
||||
// Apply routing scaling factor
|
||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
||||
|
||||
return inds, scores
|
||||
}
|
||||
|
||||
// SwitchMLP implements the MoE expert computation using stacked weights
|
||||
// Note: No weight tags - these are populated manually by stacking expert weights
|
||||
type SwitchMLP struct {
|
||||
GateWeight *mlx.Array
|
||||
UpWeight *mlx.Array
|
||||
DownWeight *mlx.Array
|
||||
}
|
||||
|
||||
// Forward applies the switched expert MLP
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
topK := cfg.NumExpertsPerTok
|
||||
|
||||
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
|
||||
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
||||
|
||||
// Flatten for gather_mm: [B*L, 1, 1, D]
|
||||
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
||||
|
||||
// Flatten indices: [B, L, topK] -> [B*L, topK]
|
||||
idxFlat := mlx.Reshape(indices, B*L, topK)
|
||||
|
||||
// Sort for efficient gather (when we have many tokens)
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
n := B * L * topK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
// Reorder x based on sorted indices
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
}
|
||||
|
||||
// Expert computation using gather_mm
|
||||
// gate: x @ gate_weight.T (indices are on the rhs/weight side)
|
||||
gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
// up: x @ up_weight.T
|
||||
up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
// SwiGLU activation
|
||||
hidden := mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
// down: hidden @ down_weight.T
|
||||
down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
// Unsort if we sorted
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// SharedExperts implements the shared expert MLP
|
||||
type SharedExperts struct {
|
||||
GateProj *nn.Linear `weight:"mlp.shared_experts.gate_proj"`
|
||||
UpProj *nn.Linear `weight:"mlp.shared_experts.up_proj"`
|
||||
DownProj *nn.Linear `weight:"mlp.shared_experts.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
type MoE struct {
|
||||
Gate *MoEGate
|
||||
SwitchMLP *SwitchMLP
|
||||
SharedExperts *SharedExperts
|
||||
}
|
||||
|
||||
// Forward applies the MoE layer
|
||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
// Get expert indices and scores
|
||||
inds, scores := m.Gate.Forward(x, cfg)
|
||||
|
||||
// Apply routed experts
|
||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
||||
|
||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedExperts != nil {
|
||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
||||
}
|
||||
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
|
||||
type DenseBlock struct {
|
||||
Attention *MLAAttention
|
||||
MLP *DenseMLP
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// MoEBlock represents a MoE transformer block
|
||||
type MoEBlock struct {
|
||||
Attention *MLAAttention
|
||||
MoE *MoE
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MoE with residual
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []Block `weight:"-"` // Loaded manually due to different block types
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead *nn.Linear `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights stacks individual expert weights into a single tensor
|
||||
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
var gateWeights, upWeights, downWeights []*mlx.Array
|
||||
|
||||
for e := int32(0); e < numExperts; e++ {
|
||||
gw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.gate_proj.weight", prefix, e))
|
||||
uw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.up_proj.weight", prefix, e))
|
||||
dw, _ := weights.GetTensor(fmt.Sprintf("%s.mlp.experts.%d.down_proj.weight", prefix, e))
|
||||
|
||||
if gw != nil {
|
||||
gateWeights = append(gateWeights, gw)
|
||||
}
|
||||
if uw != nil {
|
||||
upWeights = append(upWeights, uw)
|
||||
}
|
||||
if dw != nil {
|
||||
downWeights = append(downWeights, dw)
|
||||
}
|
||||
}
|
||||
|
||||
var stackedGate, stackedUp, stackedDown *mlx.Array
|
||||
if len(gateWeights) > 0 {
|
||||
stackedGate = mlx.Stack(gateWeights, 0)
|
||||
}
|
||||
if len(upWeights) > 0 {
|
||||
stackedUp = mlx.Stack(upWeights, 0)
|
||||
}
|
||||
if len(downWeights) > 0 {
|
||||
stackedDown = mlx.Stack(downWeights, 0)
|
||||
}
|
||||
|
||||
return stackedGate, stackedUp, stackedDown
|
||||
}
|
||||
|
||||
// Load loads a GLM4-MoE-Lite model from the given path
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute derived fields
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load embedding, norm, and lm_head
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers manually due to different block types
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d attention: %w", i, err)
|
||||
}
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
block := &DenseBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d dense: %w", i, err)
|
||||
}
|
||||
m.Layers[i] = block
|
||||
} else {
|
||||
// MoE block
|
||||
block := &MoEBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
|
||||
}
|
||||
|
||||
// Stack expert weights
|
||||
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
|
||||
|
||||
block.MoE = &MoE{
|
||||
Gate: &MoEGate{},
|
||||
SwitchMLP: &SwitchMLP{
|
||||
GateWeight: gateW,
|
||||
UpWeight: upW,
|
||||
DownWeight: downW,
|
||||
},
|
||||
}
|
||||
|
||||
// Load gate weights
|
||||
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d gate: %w", i, err)
|
||||
}
|
||||
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{}
|
||||
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
return m.LMHead.Forward(h)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
|
||||
// NumLayers returns the number of transformer layers
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
// Tokenizer returns the model's tokenizer
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
// NewCache creates a new KV cache for the model
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the GLM-4 chat template
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return "[gMASK]<sop><|user|>\n" + prompt + "<|assistant|>\n"
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -26,12 +25,11 @@ import (
|
||||
|
||||
// Request is the image generation request format
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
// Response is streamed back for each progress update
|
||||
@@ -48,13 +46,6 @@ type ImageModel interface {
|
||||
GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// ImageEditModel extends ImageModel with image editing/conditioning capability.
|
||||
// Models that support input images for editing should implement this interface.
|
||||
type ImageEditModel interface {
|
||||
ImageModel
|
||||
GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
|
||||
}
|
||||
|
||||
// Server holds the model and handles requests
|
||||
type Server struct {
|
||||
mu sync.Mutex
|
||||
@@ -170,29 +161,6 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and decode input images
|
||||
const maxInputImages = 2
|
||||
if len(req.Images) > maxInputImages {
|
||||
http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var inputImages []image.Image
|
||||
if len(req.Images) > 0 {
|
||||
// TODO: add memory check for input images
|
||||
|
||||
inputImages = make([]image.Image, len(req.Images))
|
||||
for i, imgBytes := range req.Images {
|
||||
img, err := imagegen.DecodeImage(imgBytes)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
inputImages[i] = img
|
||||
}
|
||||
slog.Info("decoded input images", "count", len(inputImages))
|
||||
}
|
||||
|
||||
// Serialize generation requests - MLX model may not handle concurrent generation
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -224,19 +192,7 @@ func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
|
||||
var img *mlx.Array
|
||||
var err error
|
||||
if len(inputImages) > 0 {
|
||||
editModel, ok := s.model.(ImageEditModel)
|
||||
if !ok {
|
||||
http.Error(w, "model does not support image editing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
|
||||
} else {
|
||||
img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
}
|
||||
img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
|
||||
|
||||
if err != nil {
|
||||
// Don't send error for cancellation
|
||||
|
||||
@@ -226,27 +226,19 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
seed = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// Extract raw image bytes from llm.ImageData slice
|
||||
var images [][]byte
|
||||
for _, img := range req.Images {
|
||||
images = append(images, img.Data)
|
||||
}
|
||||
|
||||
// Build request for subprocess
|
||||
creq := struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Images [][]byte `json:"images,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
Width int32 `json:"width,omitempty"`
|
||||
Height int32 `json:"height,omitempty"`
|
||||
Steps int32 `json:"steps,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}{
|
||||
Prompt: req.Prompt,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Steps: req.Steps,
|
||||
Seed: seed,
|
||||
Images: images,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(creq)
|
||||
|
||||
Reference in New Issue
Block a user