Compare commits

...

4 Commits

Author  SHA1  Message  Date
Patrick Devine  050b0a03a6  fix --imagegen flag  2026-02-13 22:18:27 -08:00
Patrick Devine  8faae6e443  remove log lines  2026-02-13 21:57:06 -08:00
Patrick Devine  f354af3190  fix loading diffusion models  2026-02-13 21:51:54 -08:00
Patrick Devine  967bedce30  load glm4_moe_lite from the mlxrunner  2026-02-13 19:00:35 -08:00
19 changed files with 764 additions and 281 deletions

View File

@@ -581,6 +581,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
useImagegen := false
if cmd.Flags().Lookup("imagegen") != nil {
useImagegen, err = cmd.Flags().GetBool("imagegen")
if err != nil {
return err
}
}
if useImagegen {
opts.Options["use_imagegen_runner"] = true
}
// Fill out the rest of the options based on information about the
// model.
client, err := api.ClientFromEnvironment()
@@ -2130,6 +2141,9 @@ func NewCLI() *cobra.Command {
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
runCmd.Flags().MarkHidden("imagegen")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",

View File

@@ -144,12 +144,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
delete(requestOpts, "use_imagegen_runner")
opts, err := s.modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
var runner *runnerRef
select {
case runner = <-runnerCh:
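
As a small aside, the sentinel-option handling above boils down to this pattern; a runnable sketch, where the neighboring "temperature" key is illustrative only.

package main

import "fmt"

func main() {
    // The boolean rides in the generic options map, is read with a two-value
    // type assertion (false if absent or of the wrong type), and is deleted
    // so it never reaches modelOptions.
    requestOpts := map[string]any{"use_imagegen_runner": true, "temperature": 0.7}
    useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
    delete(requestOpts, "use_imagegen_runner")
    fmt.Println(useImagegen, requestOpts)
}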

View File

@@ -2383,6 +2383,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
llama: &mock,
Options: &opts,
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
isImagegen: true,
numParallel: 1,
},
},

View File

@@ -22,6 +22,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -32,6 +33,7 @@ type LlmRequest struct {
successCh chan *runnerRef
errCh chan error
schedAttempts uint
useImagegen bool
}
type Scheduler struct {
@@ -82,7 +84,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -99,6 +101,7 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
useImagegen: useImagegen,
}
s.loadedMu.Lock()
@@ -566,17 +569,20 @@ iGPUScan:
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {
mode = imagegen.ModeImageGen
} else {
mode = imagegen.ModeLLM
}
// Use model name for MLX (it resolves manifests by name, not file path)
modelName := req.model.ShortName
server, err := imagegen.NewServer(modelName, mode)
var server llm.LlamaServer
var err error
isImagegen := false
if slices.Contains(req.model.Config.Capabilities, "image") {
server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
isImagegen = true
} else if req.useImagegen {
server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
isImagegen = true
} else {
server, err = mlxrunner.NewClient(modelName)
}
if err != nil {
req.errCh <- err
return true
@@ -593,6 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
llama: server,
Options: &req.opts,
loading: false,
isImagegen: isImagegen,
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
@@ -667,6 +674,7 @@ type runnerRef struct {
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
vramSize uint64
totalSize uint64
@@ -699,6 +707,12 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Check if runner type (imagegen vs mlxrunner) matches what's requested
wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
if runner.isImagegen != wantImagegen {
return true
}
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
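
To summarize the runner selection introduced in loadMLX, here is a hedged restatement as a hypothetical helper. pickServer does not exist in this change; the package paths and signatures are taken from the hunks above, and the build tag mirrors the new mlxrunner files in this diff.

//go:build mlx

package example

import (
    "slices"

    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/mlxrunner"
)

// pickServer restates the decision: image-capable models always use the
// imagegen server, the hidden --imagegen flag forces it for text models,
// and everything else goes through the new mlxrunner client. The returned
// bool mirrors runnerRef.isImagegen.
func pickServer(modelName string, caps []string, useImagegen bool) (llm.LlamaServer, bool, error) {
    switch {
    case slices.Contains(caps, "image"):
        s, err := imagegen.NewServer(modelName, imagegen.ModeImageGen)
        return s, true, err
    case useImagegen:
        s, err := imagegen.NewServer(modelName, imagegen.ModeLLM)
        return s, true, err
    default:
        s, err := mlxrunner.NewClient(modelName)
        return s, false, err
    }
}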

View File

@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
// Starts in pending channel, then should be quickly processed to return an error
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
@@ -509,7 +509,7 @@ func TestSchedPrematureExpired(t *testing.T) {
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {

View File

@@ -102,15 +102,20 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
for _, entry := range entries {
name := entry.name
// Try to get tensor by stripped name first, then with component prefix.
// Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
// while the tensors map uses stripped names (e.g., "model.layers.0.weight").
// Try to get tensor by stripped name first, then with component prefix,
// then fall back to "data" for legacy blobs created by older versions
// that stored all tensors with the generic key "data".
lookupName := name
arr := sf.Get(lookupName)
if arr == nil && mw.component != "" {
lookupName = mw.component + "/" + name
arr = sf.Get(lookupName)
}
if arr == nil {
// Legacy blob format: tensor stored as "data"
lookupName = "data"
arr = sf.Get(lookupName)
}
if arr != nil {
// Single-tensor blob or tensor found by name
if dtype != 0 && arr.Dtype() != dtype {
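
The lookup order described in the comment above can be restated as a small hypothetical helper; the concrete tensor type is not named in this hunk, so a stand-in struct is used.

package example

// Tensor is a stand-in for the real tensor type returned by the safetensors reader.
type Tensor struct{}

// lookupTensor tries the stripped name, then the component-prefixed name,
// then the legacy "data" key used by older single-tensor blobs.
func lookupTensor(get func(name string) *Tensor, component, name string) *Tensor {
    if t := get(name); t != nil { // e.g. "model.layers.0.weight"
        return t
    }
    if component != "" {
        if t := get(component + "/" + name); t != nil { // e.g. "text_encoder/model.layers.0.weight"
            return t
        }
    }
    return get("data") // legacy blobs stored their single tensor under "data"
}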

View File

@@ -2,76 +2,298 @@ package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
Port int
*exec.Cmd
port int
modelName string
vramSize uint64
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
}
func (c *Client) JoinPath(path string) string {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
}).JoinPath(path).String()
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
func NewClient(modelName string) (*Client, error) {
if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
// Get the current executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{
port: port,
modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd,
}
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
slog.Warn("mlx-runner", "msg", line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()
}
}()
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
c.done <- err
}()
// Wait for subprocess to be ready
if err := c.waitUntilRunning(); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func (c *Client) CheckError(w *http.Response) error {
if w.StatusCode >= 400 {
return errors.New(w.Status)
func (c *Client) getLastErr() string {
c.lastErrLock.Lock()
defer c.lastErrLock.Unlock()
return c.lastErr
}
func (c *Client) waitUntilRunning() error {
ctx := context.Background()
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case err := <-c.done:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := c.Ping(ctx); err == nil {
slog.Info("mlx runner is ready", "port", c.port)
return nil
}
}
}
}
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
Prompt string `json:"prompt"`
Options *completionOpts `json:"options,omitempty"`
}
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
c.cmd.Process.Signal(os.Interrupt)
select {
case <-c.done:
case <-time.After(5 * time.Second):
c.cmd.Process.Kill()
}
c.cmd = nil
}
return nil
}
// Close implements llm.LlamaServer.
func (c *Client) Close() error {
return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(req); err != nil {
return err
creq := completionRequest{
Prompt: req.Prompt,
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
NumPredict: req.Options.NumPredict,
}
}
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
body, err := json.Marshal(creq)
if err != nil {
return err
}
defer w.Body.Close()
if err := c.CheckError(w); err != nil {
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
scanner := bufio.NewScanner(w.Body)
for scanner.Scan() {
bts := scanner.Bytes()
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
var resp llm.CompletionResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
fn(resp)
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
}
return nil
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
if cresp.Done {
return nil
}
}
return scanner.Err()
}
func (c *Client) ContextLength() int {
@@ -80,71 +302,89 @@ func (c *Client) ContextLength() int {
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
panic("unimplemented")
return "", errors.New("not supported")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
panic("unimplemented")
return nil, 0, errors.New("not supported")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
panic("unimplemented")
return nil
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.Port
return c.port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
panic("unimplemented")
select {
case <-c.done:
return true
default:
return false
}
}
// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
if err != nil {
return nil, err
}
defer w.Body.Close()
return []ml.DeviceID{}, nil
return nil, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
panic("unimplemented")
return c.modelName
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
panic("unimplemented")
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
return c.cmd.Process.Pid
}
return -1
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
w, err := http.Get(c.JoinPath("/v1/status"))
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return err
}
defer w.Body.Close()
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
if err != nil {
return nil, err
}
defer w.Body.Close()
req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokens []int
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
return nil, err
}
@@ -153,22 +393,22 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
panic("unimplemented")
return c.vramSize
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
panic("unimplemented")
return c.vramSize
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
panic("unimplemented")
return c.vramSize
}
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
panic("unimplemented")
return nil
}
var _ llm.LlamaServer = (*Client)(nil)
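
A minimal usage sketch for the new client, using only methods and fields shown in this diff; the model name is illustrative, and the mlxrunner package is assumed to build behind the mlx tag like the other new files here.

//go:build mlx

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/x/mlxrunner"
)

func main() {
    // NewClient resolves the model by manifest name, spawns the runner
    // subprocess, and blocks until the health endpoint responds.
    c, err := mlxrunner.NewClient("my-model")
    if err != nil {
        log.Fatal(err)
    }
    defer c.Close()

    req := llm.CompletionRequest{Prompt: "Hello"}
    err = c.Completion(context.Background(), req, func(r llm.CompletionResponse) {
        fmt.Print(r.Content) // streamed chunks; r.Done marks the final one
    })
    if err != nil {
        log.Fatal(err)
    }
}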

x/mlxrunner/imports.go  (new file, 7 lines)
View File

@@ -0,0 +1,7 @@
//go:build mlx
package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
)

View File

@@ -133,6 +133,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}
func (t *Array) Set(other *Array) {
Free(t.desc.inputs...)
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
@@ -248,9 +249,9 @@ func Free(s ...*Array) (n int) {
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--
if t.desc.numRefs <= 0 {
free = append(free, t.desc.inputs...)
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
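
A hedged illustration of the reference-count semantics after this change, assuming this hunk is the x/mlxrunner/mlx array code and a build with the mlx tag; values are illustrative.

//go:build mlx

package main

import "github.com/ollama/ollama/x/mlxrunner/mlx"

func main() {
    a := mlx.FromValues([]float32{1, 2, 3}, 3)
    b := mlx.FromValues([]float32{0, 0, 0}, 3)
    // Set records a as an input of b and bumps a's reference count.
    b.Set(a)
    // Free decrements b's count and releases it only at zero; its inputs
    // (including a) are then decremented in turn, so a shared input is not
    // freed while another array still references it.
    mlx.Free(b)
}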

View File

@@ -24,6 +24,37 @@ func CheckInit() error {
return initError
}
// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
// Returns true if the library was successfully loaded.
func tryLoadFromDir(dir string) bool {
matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
if err != nil || len(matches) == 0 {
return false
}
for _, match := range matches {
path := filepath.Join(dir, match)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
return true
}
return false
}
func init() {
switch runtime.GOOS {
case "darwin":
@@ -33,44 +64,34 @@ func init() {
return
}
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
return
// Try OLLAMA_LIBRARY_PATH first
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
for _, dir := range filepath.SplitList(paths) {
if tryLoadFromDir(dir) {
return
}
}
}
for _, path := range filepath.SplitList(paths) {
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
if err != nil {
initError = fmt.Errorf("failed to glob for MLX libraries in %s: %w", path, err)
slog.Warn("MLX dynamic library not available", "error", initError)
return
// Build search paths: executable directory, then build directories
var searchDirs []string
if exe, err := os.Executable(); err == nil {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
searchDirs = append(searchDirs, filepath.Dir(exe))
}
for _, match := range matches {
path := filepath.Join(paths, match)
slog.Info("Loading MLX dynamic library", "path", path)
if cwd, err := os.Getwd(); err == nil {
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
}
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
slog.Info("Loaded MLX dynamic library", "path", path)
for _, dir := range searchDirs {
if tryLoadFromDir(dir) {
return
}
}
initError = fmt.Errorf("failed to load any MLX dynamic library from OLLAMA_LIBRARY_PATH=%s", paths)
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
slog.Warn("MLX dynamic library not available", "error", initError)
}

View File

@@ -306,19 +306,42 @@ func AddMM(c, a, b *Array, alpha, beta float32) *Array {
// Scalar helpers
// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.DType()
if dtype == DTypeFloat32 {
return f32
}
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_array_free(f32)
return casted
}
func AddScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Add(scalar)
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR", a)
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func MulScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Multiply(scalar)
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR", a)
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func DivScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return a.Divide(scalar)
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR", a)
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func FloorDivideScalar(a *Array, s int32) *Array {
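
A short usage sketch of the scalar helpers after this change; the point of scalarWithDtype is that the scalar operand is cast to the array's dtype, so a reduced-precision tensor is not promoted by a float32 scalar. Values are illustrative and an mlx-tagged build is assumed.

//go:build mlx

package main

import "github.com/ollama/ollama/x/mlxrunner/mlx"

func main() {
    x := mlx.FromValues([]float32{1, 2, 3}, 3)
    y := mlx.MulScalar(x, 0.5) // scalar is materialized in x's dtype
    z := mlx.AddScalar(y, 1.0)
    _ = mlx.DivScalar(z, 2.0)
}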

View File

@@ -0,0 +1,85 @@
//go:build mlx
package base
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
)
// Model is the interface that model implementations must satisfy.
type Model interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert
// stacking, quantized layer creation) happens here.
LoadWeights(tensors map[string]*mlx.Array) error
}
var (
mu sync.Mutex
registry = make(map[string]func(root *model.Root) (Model, error))
)
// Register registers a model constructor by architecture name.
// Called from init() in model packages. Panics on duplicate registration.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
mu.Lock()
defer mu.Unlock()
if _, exists := registry[arch]; exists {
panic(fmt.Sprintf("model architecture %q already registered", arch))
}
registry[arch] = fn
}
// New reads config.json from the manifest, detects the architecture, looks up
// the registered constructor, and calls it to create the model (with config
// parsed and struct created, but weights not yet loaded).
func New(root *model.Root) (Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return nil, fmt.Errorf("failed to parse config.json: %w", err)
}
if len(archConfig.Architectures) == 0 {
return nil, fmt.Errorf("no architectures found in config.json")
}
arch := archConfig.Architectures[0]
slog.Info("Model architecture", "arch", arch)
mu.Lock()
fn, ok := registry[arch]
mu.Unlock()
if !ok {
return nil, fmt.Errorf("unsupported architecture: %s", arch)
}
return fn(root)
}
// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
func Weights(m Model) func(map[string]*mlx.Array) error {
return m.LoadWeights
}
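
To illustrate the contract, a hypothetical model package that satisfies base.Model and registers itself, mirroring what glm4_moe_lite does later in this diff; the architecture and package names are made up.

//go:build mlx

package mymodel

import (
    "github.com/ollama/ollama/x/imagegen/tokenizer"
    "github.com/ollama/ollama/x/mlxrunner/cache"
    "github.com/ollama/ollama/x/mlxrunner/mlx"
    "github.com/ollama/ollama/x/mlxrunner/model"
    "github.com/ollama/ollama/x/mlxrunner/model/base"
)

// myModel is a stand-in architecture used only to show the required methods.
type myModel struct{ tok *tokenizer.Tokenizer }

func (m *myModel) Forward(inputs *mlx.Array, c []cache.Cache) *mlx.Array { return inputs }
func (m *myModel) Unembed(x *mlx.Array) *mlx.Array                       { return x }
func (m *myModel) NumLayers() int                                        { return 0 }
func (m *myModel) Tokenizer() *tokenizer.Tokenizer                       { return m.tok }
func (m *myModel) LoadWeights(tensors map[string]*mlx.Array) error       { return nil }

func init() {
    // Registered under the architecture string that appears in config.json;
    // base.New resolves it and calls this constructor before weights load.
    base.Register("MyModelForCausalLM", func(root *model.Root) (base.Model, error) {
        return &myModel{}, nil
    })
}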

View File

@@ -0,0 +1,3 @@
//go:build !mlx
package base

x/mlxrunner/model/root.go  (new file, 97 lines)
View File

@@ -0,0 +1,97 @@
//go:build mlx
package model
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
quantType string
groupSize int
}
// Open loads a manifest for the given model name and pre-scans the first
// tensor blob for quantization metadata (quant_type, group_size).
func Open(modelName string) (*Root, error) {
m, err := manifest.LoadManifest(modelName)
if err != nil {
return nil, err
}
root := &Root{Manifest: m}
// Pre-scan first tensor blob for quantization metadata
for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest)
meta, err := readBlobMetadata(blobPath)
if err != nil || meta == nil {
continue
}
if qt := meta["quant_type"]; qt != "" {
root.quantType = strings.ToUpper(qt)
}
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &root.groupSize)
}
break // only check the first tensor blob
}
return root, nil
}
// Close is a no-op for now (future: release resources).
func (r *Root) Close() {}
// QuantType returns the quantization type detected from tensor metadata.
func (r *Root) QuantType() string { return r.quantType }
// GroupSize returns the quantization group size detected from tensor metadata.
func (r *Root) GroupSize() int { return r.groupSize }
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
}
metaRaw, ok := header["__metadata__"]
if !ok {
return nil, nil
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err
}
return meta, nil
}
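
A hedged sketch of the blob prefix readBlobMetadata parses: an 8-byte little-endian header length, then a JSON header whose optional "__metadata__" entry is a string-to-string map. The metadata values below are illustrative, not taken from a real model.

package main

import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "fmt"
)

func main() {
    header := map[string]any{
        "__metadata__": map[string]string{"quant_type": "nf4", "group_size": "64"},
    }
    hdrJSON, _ := json.Marshal(header)

    var blob bytes.Buffer
    binary.Write(&blob, binary.LittleEndian, uint64(len(hdrJSON)))
    blob.Write(hdrJSON)
    // Tensor definitions and raw tensor data would follow in a real blob.
    fmt.Printf("header size prefix: %d bytes of JSON\n", len(hdrJSON))
}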

View File

@@ -0,0 +1,3 @@
//go:build !mlx
package model

View File

@@ -18,6 +18,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
mlx.EnableCompile()
inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs)
@@ -47,7 +49,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches))
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
@@ -60,7 +63,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer
now := time.Now()
final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)

View File

@@ -4,30 +4,22 @@ package mlxrunner
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/models/glm4_moe_lite"
)
// TextModel is the interface that model implementations must satisfy.
type TextModel interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
}
type Request struct {
TextCompletionsRequest
Responses chan Response
@@ -66,52 +58,95 @@ type Response struct {
}
type Runner struct {
Model TextModel
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
}
func (r *Runner) Load(modelName string) error {
modelManifest, err := manifest.LoadManifest(modelName)
root, err := model.Open(modelName)
if err != nil {
return err
}
defer root.Close()
m, err := base.New(root)
if err != nil {
return err
}
// Read config to detect architecture
configData, err := modelManifest.ReadConfig("config.json")
// Load all tensor blobs from manifest
tensors, err := loadTensorsFromManifest(root)
if err != nil {
return fmt.Errorf("failed to read config.json: %w", err)
return err
}
var archConfig struct {
Architectures []string `json:"architectures"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return fmt.Errorf("failed to parse config.json: %w", err)
}
if len(archConfig.Architectures) == 0 {
return fmt.Errorf("no architectures found in config.json")
}
slog.Info("Model architecture", "arch", archConfig.Architectures[0])
switch archConfig.Architectures[0] {
case "Glm4MoeLiteForCausalLM", "GLM4MoeLite":
model, err := glm4_moe_lite.LoadFromManifest(modelManifest)
if err != nil {
return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err)
}
r.Model = model
r.Tokenizer = model.Tokenizer()
default:
return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0])
// Assign weights to model (model-specific logic)
loadWeights := base.Weights(m)
if err := loadWeights(tensors); err != nil {
return err
}
r.Model = m
r.Tokenizer = m.Tokenizer()
return nil
}
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
// flat map, deduplicating by digest and remapping safetensors key suffixes.
//
// Uses a phased approach: first loads all raw tensors, then records which
// base names have ".scale" entries, and only then remaps ".bias" → "_qbias".
// This avoids an ordering hazard where Go's unordered map iteration could
// visit a ".bias" tensor before its matching ".scale" within the same blob.
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
// Phase 1: Load all tensors raw from all blobs
rawTensors := make(map[string]*mlx.Array)
seen := make(map[string]bool)
for _, layer := range root.Manifest.GetTensorLayers("") {
if seen[layer.Digest] {
continue
}
seen[layer.Digest] = true
blobPath := root.Manifest.BlobPath(layer.Digest)
for name, arr := range mlx.Load(blobPath) {
rawTensors[name] = arr
}
}
// Phase 2: Identify all base names that have .scale tensors and remap them
scaleBaseNames := make(map[string]bool)
allTensors := make(map[string]*mlx.Array, len(rawTensors))
for name, arr := range rawTensors {
if strings.HasSuffix(name, ".scale") {
baseName := strings.TrimSuffix(name, ".scale")
allTensors[baseName+"_scale"] = arr
scaleBaseNames[baseName] = true
}
}
// Phase 3: Process remaining tensors with complete scale knowledge
for name, arr := range rawTensors {
if strings.HasSuffix(name, ".scale") {
continue // already handled
}
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
baseName := strings.TrimSuffix(name, ".bias")
if scaleBaseNames[baseName] {
allTensors[baseName+"_qbias"] = arr
} else {
allTensors[name] = arr
}
} else {
allTensors[name] = arr
}
}
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
return allTensors, nil
}
func (r *Runner) Run(host, port string, mux http.Handler) error {
g, ctx := errgroup.WithContext(context.Background())
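
The suffix remapping in loadTensorsFromManifest can be restated as a small, dependency-free helper. This is hypothetical: the real code carries tensor values alongside the names, and the example names are illustrative.

package main

import (
    "fmt"
    "strings"
)

// remapNames applies the same rules as above, on names only: ".scale" becomes
// "_scale", ".bias" becomes "_qbias" when a matching ".scale" sibling exists,
// and everything else keeps its name.
func remapNames(names []string) map[string]string {
    hasScale := make(map[string]bool)
    for _, n := range names {
        if strings.HasSuffix(n, ".scale") {
            hasScale[strings.TrimSuffix(n, ".scale")] = true
        }
    }
    out := make(map[string]string, len(names))
    for _, n := range names {
        switch {
        case strings.HasSuffix(n, ".scale"):
            out[n] = strings.TrimSuffix(n, ".scale") + "_scale"
        case strings.HasSuffix(n, ".bias") && hasScale[strings.TrimSuffix(n, ".bias")]:
            out[n] = strings.TrimSuffix(n, ".bias") + "_qbias" // quantization bias
        default:
            out[n] = n // regular bias or any other tensor keeps its name
        }
    }
    return out
}

func main() {
    fmt.Println(remapNames([]string{"lm_head.weight", "lm_head.scale", "lm_head.bias", "model.norm.weight"}))
}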

View File

@@ -52,7 +52,7 @@ func (c chain) Sample(logits *mlx.Array) *mlx.Array {
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
return mlx.DivScalar(logits, float32(t)).Categorical(-1)
}
type TopP float32

View File

@@ -5,21 +5,24 @@
package glm4_moe_lite
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"os"
"strings"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
func init() {
base.Register("Glm4MoeLiteForCausalLM", newModel)
base.Register("GLM4MoeLite", newModel)
}
// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
Factor float32 `json:"factor"`
@@ -131,7 +134,6 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
out = a.UnembedOut.Forward(out)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
@@ -386,44 +388,6 @@ func quantizationParams(quantization string) (groupSize, bits int, mode string)
}
}
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err
}
if headerSize > 1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, err
}
metaRaw, ok := header["__metadata__"]
if !ok {
return nil, nil
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err
}
return meta, nil
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array
@@ -569,9 +533,10 @@ func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.Line
return nn.NewLinear(w, bias)
}
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
configData, err := modelManifest.ReadConfig("config.json")
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
@@ -584,66 +549,18 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
// Load all tensors from manifest blobs into a flat map
allTensors := make(map[string]*mlx.Array)
seen := make(map[string]bool) // dedupe by digest
var quantType string
var quantGroupSize int
for _, layer := range modelManifest.GetTensorLayers("") {
if seen[layer.Digest] {
continue
}
seen[layer.Digest] = true
blobPath := modelManifest.BlobPath(layer.Digest)
// Read quantization metadata from first blob
if quantType == "" {
if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil {
if qt := meta["quant_type"]; qt != "" {
quantType = strings.ToUpper(qt)
}
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &quantGroupSize)
}
}
}
for name, arr := range mlx.Load(blobPath) {
// Map safetensors key naming to our naming convention
// Combined blobs use ".scale" and ".bias" suffixes
if strings.HasSuffix(name, ".scale") {
baseName := strings.TrimSuffix(name, ".scale")
allTensors[baseName+"_scale"] = arr
} else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
// Check if this is a quantization bias or a regular bias
// by checking if there's a corresponding weight
baseName := strings.TrimSuffix(name, ".bias")
if _, hasScale := allTensors[baseName+"_scale"]; hasScale {
allTensors[baseName+"_qbias"] = arr
} else {
allTensors[name] = arr
}
} else {
allTensors[name] = arr
}
}
}
// Set up quantization parameters
useQuantized := false
if quantType != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(quantType)
if quantGroupSize > 0 {
cfg.QuantGroupSize = quantGroupSize
// Set up quantization parameters from pre-scanned metadata
if qt := root.QuantType(); qt != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
} else {
cfg.QuantGroupSize, _, _ = quantizationParams(quantType)
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
}
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
}
// Load tokenizer
tokData, err := modelManifest.ReadConfig("tokenizer.json")
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
@@ -652,11 +569,11 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
ConfigJSON: configData,
}
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
@@ -671,18 +588,28 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. Handles MLA absorption, expert stacking, and quantized
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
cfg := m.Config
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
// Load embedding
if w := allTensors["model.embed_tokens.weight"]; w != nil {
if w := tensors["model.embed_tokens.weight"]; w != nil {
m.EmbedTokens = nn.NewEmbedding(w)
}
// Load final norm
if w := allTensors["model.norm.weight"]; w != nil {
if w := tensors["model.norm.weight"]; w != nil {
m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
// Load LM head
m.LMHead = makeLinear(allTensors, "lm_head", &cfg)
m.LMHead = makeLinear(tensors, "lm_head", cfg)
// Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
@@ -690,24 +617,24 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
// Load attention (same for both block types)
attn := &MLAAttention{}
attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg)
if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg)
attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg)
if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg)
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg)
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
inputLN := allTensors[prefix+".input_layernorm.weight"]
postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"]
inputLN := tensors[prefix+".input_layernorm.weight"]
postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]
if i < cfg.FirstKDenseReplace {
// Dense block
@@ -720,9 +647,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
block.MLP = &DenseMLP{
GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg),
UpProj: makeLinear(allTensors, prefix+".mlp.up_proj", &cfg),
DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg),
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
}
m.Layers[i] = block
@@ -737,7 +664,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
// Stack expert weights
gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
@@ -763,8 +690,8 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
moeGate := &MoEGate{}
moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg)
if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias
}
@@ -776,9 +703,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{
GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg),
UpProj: makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg),
DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg),
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
}
}
@@ -786,9 +713,10 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
}
}
mlx.Eval(mlx.Collect(m)...)
collected := mlx.Collect(m)
mlx.Eval(collected...)
return m, nil
return nil
}
// Forward computes the forward pass of the model