Mirror of https://github.com/ollama/ollama.git (synced 2026-02-14 01:23:04 -05:00)

Compare commits: main...pdevine/ml (4 commits)

| SHA1 |
|---|
| 050b0a03a6 |
| 8faae6e443 |
| f354af3190 |
| 967bedce30 |
14	cmd/cmd.go
@@ -581,6 +581,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
    }
    opts.WordWrap = !nowrap

    useImagegen := false
    if cmd.Flags().Lookup("imagegen") != nil {
        useImagegen, err = cmd.Flags().GetBool("imagegen")
        if err != nil {
            return err
        }
    }
    if useImagegen {
        opts.Options["use_imagegen_runner"] = true
    }

    // Fill out the rest of the options based on information about the
    // model.
    client, err := api.ClientFromEnvironment()
@@ -2130,6 +2141,9 @@ func NewCLI() *cobra.Command {
    // Image generation flags (width, height, steps, seed, etc.)
    imagegen.RegisterFlags(runCmd)

    runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
    runCmd.Flags().MarkHidden("imagegen")

    stopCmd := &cobra.Command{
        Use:   "stop MODEL",
        Short: "Stop a running model",
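Reviewer note: the two cmd/cmd.go hunks above add a hidden `--imagegen` boolean flag and forward it to the server as a `use_imagegen_runner` option. A minimal, self-contained cobra sketch of the same pattern (not the actual ollama CLI; the map-based `opts` value is illustrative):

```go
// Hidden boolean flag whose value is stashed into request options.
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

func main() {
	runCmd := &cobra.Command{
		Use: "run MODEL",
		RunE: func(cmd *cobra.Command, args []string) error {
			opts := map[string]any{}
			// Lookup guards against the flag not being registered in some builds.
			if cmd.Flags().Lookup("imagegen") != nil {
				useImagegen, err := cmd.Flags().GetBool("imagegen")
				if err != nil {
					return err
				}
				if useImagegen {
					opts["use_imagegen_runner"] = true
				}
			}
			fmt.Println(opts)
			return nil
		},
	}
	runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
	runCmd.Flags().MarkHidden("imagegen")
	runCmd.Execute()
}
```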
@@ -144,12 +144,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
        return nil, nil, nil, fmt.Errorf("%s %w", name, err)
    }

    useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
    delete(requestOpts, "use_imagegen_runner")

    opts, err := s.modelOptions(model, requestOpts)
    if err != nil {
        return nil, nil, nil, err
    }

    runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
    runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
    var runner *runnerRef
    select {
    case runner = <-runnerCh:
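The scheduleRunner hunk pops the transient `use_imagegen_runner` key out of the generic request options before they become model options, so the routing hint never leaks into inference parameters. A tiny sketch of that consume-and-delete pattern (`popBool` is a hypothetical helper, not code from the diff):

```go
package main

import "fmt"

// popBool reads a bool out of a generic options map and removes the key.
func popBool(opts map[string]any, key string) bool {
	v, _ := opts[key].(bool)
	delete(opts, key)
	return v
}

func main() {
	requestOpts := map[string]any{"use_imagegen_runner": true, "temperature": 0.7}
	useImagegen := popBool(requestOpts, "use_imagegen_runner")
	fmt.Println(useImagegen, requestOpts) // true map[temperature:0.7]
}
```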
@@ -2383,6 +2383,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
            llama:       &mock,
            Options:     &opts,
            model:       &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
            isImagegen:  true,
            numParallel: 1,
        },
    },
@@ -22,6 +22,7 @@ import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/types/model"
    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/mlxrunner"
)

type LlmRequest struct {
@@ -32,6 +33,7 @@ type LlmRequest struct {
    successCh       chan *runnerRef
    errCh           chan error
    schedAttempts   uint
    useImagegen     bool
}

type Scheduler struct {
@@ -82,7 +84,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}

// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
    if opts.NumCtx < 4 {
        opts.NumCtx = 4
    }
@@ -99,6 +101,7 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
        sessionDuration: sessionDuration,
        successCh:       make(chan *runnerRef, 1),
        errCh:           make(chan error, 1),
        useImagegen:     useImagegen,
    }

    s.loadedMu.Lock()
@@ -566,17 +569,20 @@ iGPUScan:
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
    // Determine mode based on capabilities
    var mode imagegen.ModelMode
    if slices.Contains(req.model.Config.Capabilities, "image") {
        mode = imagegen.ModeImageGen
    } else {
        mode = imagegen.ModeLLM
    }

    // Use model name for MLX (it resolves manifests by name, not file path)
    modelName := req.model.ShortName
    server, err := imagegen.NewServer(modelName, mode)
    var server llm.LlamaServer
    var err error

    isImagegen := false
    if slices.Contains(req.model.Config.Capabilities, "image") {
        server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
        isImagegen = true
    } else if req.useImagegen {
        server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
        isImagegen = true
    } else {
        server, err = mlxrunner.NewClient(modelName)
    }
    if err != nil {
        req.errCh <- err
        return true
@@ -593,6 +599,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
        llama:           server,
        Options:         &req.opts,
        loading:         false,
        isImagegen:      isImagegen,
        sessionDuration: sessionDuration,
        totalSize:       server.TotalSize(),
        vramSize:        server.VRAMSize(),
@@ -667,6 +674,7 @@ type runnerRef struct {
    loading      bool          // True only during initial load, then false forever
    gpus         []ml.DeviceID // Recorded at time of provisioning
    discreteGPUs bool          // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
    isImagegen   bool          // True if loaded via imagegen runner (vs mlxrunner)
    vramSize     uint64
    totalSize    uint64

@@ -699,6 +707,12 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
    runner.refMu.Lock()
    defer runner.refMu.Unlock()

    // Check if runner type (imagegen vs mlxrunner) matches what's requested
    wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
    if runner.isImagegen != wantImagegen {
        return true
    }

    timeout := 10 * time.Second
    if runner.loading {
        timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
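Taken together, the scheduler hunks encode one routing rule: models with the "image" capability always go to the imagegen runner, plain LLMs go to the mlxrunner client unless the request carries the useImagegen hint, and needsReload forces a reload whenever a loaded runner's kind no longer matches the request. A stand-alone sketch of that rule (`runnerKind` and `pickRunner` are illustrative names, not real types from the diff):

```go
package main

import "fmt"

type runnerKind int

const (
	runnerMLX      runnerKind = iota // mlxrunner.NewClient
	runnerImagegen                   // imagegen.NewServer
)

// pickRunner mirrors the branch order in loadMLX: capability first, then flag.
func pickRunner(capabilities []string, useImagegen bool) runnerKind {
	for _, c := range capabilities {
		if c == "image" {
			return runnerImagegen // image models always use the imagegen runner
		}
	}
	if useImagegen {
		return runnerImagegen // hidden --imagegen flag forces it for LLMs too
	}
	return runnerMLX
}

func main() {
	fmt.Println(pickRunner([]string{"image"}, false)) // 1
	fmt.Println(pickRunner(nil, true))                // 1
	fmt.Println(pickRunner(nil, false))               // 0
}
```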
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
    s.getSystemInfoFn = getSystemInfoFn
    s.newServerFn = a.newServer
    slog.Info("a")
    successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
    successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
    require.Len(t, s.pendingReqCh, 1)
    slog.Info("b")
    successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
    successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
    require.Len(t, s.pendingReqCh, 1)
    require.Empty(t, successCh1b)
    require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {

    c.req.model.ModelPath = "bad path"
    slog.Info("c")
    successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
    successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
    // Starts in pending channel, then should be quickly processed to return an error
    time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
    require.Empty(t, successCh1c)
@@ -509,7 +509,7 @@ func TestSchedPrematureExpired(t *testing.T) {
    s.getGpuFn = getGpuFn
    s.getSystemInfoFn = getSystemInfoFn
    s.newServerFn = scenario1a.newServer
    successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
    successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
    require.Len(t, s.pendingReqCh, 1)
    s.Run(ctx)
    select {
@@ -102,15 +102,20 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
    for _, entry := range entries {
        name := entry.name

        // Try to get tensor by stripped name first, then with component prefix.
        // Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
        // while the tensors map uses stripped names (e.g., "model.layers.0.weight").
        // Try to get tensor by stripped name first, then with component prefix,
        // then fall back to "data" for legacy blobs created by older versions
        // that stored all tensors with the generic key "data".
        lookupName := name
        arr := sf.Get(lookupName)
        if arr == nil && mw.component != "" {
            lookupName = mw.component + "/" + name
            arr = sf.Get(lookupName)
        }
        if arr == nil {
            // Legacy blob format: tensor stored as "data"
            lookupName = "data"
            arr = sf.Get(lookupName)
        }
        if arr != nil {
            // Single-tensor blob or tensor found by name
            if dtype != 0 && arr.Dtype() != dtype {
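The lookup fallback above tries three keys in order. A minimal sketch with a plain map standing in for the safetensors reader (`lookup` is a hypothetical helper; `sf.Get` is the real call in the diff):

```go
package main

import "fmt"

// lookup tries: stripped name, component-prefixed name, then the legacy "data" key.
func lookup(tensors map[string]string, component, name string) (string, bool) {
	if v, ok := tensors[name]; ok { // 1. stripped name
		return v, true
	}
	if component != "" { // 2. component-prefixed name
		if v, ok := tensors[component+"/"+name]; ok {
			return v, true
		}
	}
	if v, ok := tensors["data"]; ok { // 3. legacy single-tensor blobs
		return v, true
	}
	return "", false
}

func main() {
	blob := map[string]string{"text_encoder/model.layers.0.weight": "w0"}
	v, ok := lookup(blob, "text_encoder", "model.layers.0.weight")
	fmt.Println(v, ok) // w0 true
}
```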
@@ -2,76 +2,298 @@ package mlxrunner

import (
    "bufio"
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "math"
    "math/rand"
    "net"
    "net/http"
    "net/url"
    "os"
    "os/exec"
    "path/filepath"
    "runtime"
    "strconv"
    "strings"
    "sync"
    "time"

    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/x/imagegen"
    "github.com/ollama/ollama/x/imagegen/manifest"
)

// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
    Port int
    *exec.Cmd
    port        int
    modelName   string
    vramSize    uint64
    done        chan error
    client      *http.Client
    lastErr     string
    lastErrLock sync.Mutex
    mu          sync.Mutex
    cmd         *exec.Cmd
}

func (c *Client) JoinPath(path string) string {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
|
||||
}).JoinPath(path).String()
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
func NewClient(modelName string) (*Client, error) {
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find a free port
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
if l, err := net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
|
||||
// Get the current executable path
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
||||
if runtime.GOOS == "linux" {
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
cmd: cmd,
|
||||
}
|
||||
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
slog.Info("mlx-runner", "msg", scanner.Text())
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
slog.Warn("mlx-runner", "msg", line)
|
||||
c.lastErrLock.Lock()
|
||||
c.lastErr = line
|
||||
c.lastErrLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||
}
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
c.done <- err
|
||||
}()
|
||||
|
||||
// Wait for subprocess to be ready
|
||||
if err := c.waitUntilRunning(); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
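NewClient picks its port by binding to port 0 and falling back to a random port in the dynamic range. A self-contained sketch of just that step (`freePort` is an illustrative name):

```go
package main

import (
	"fmt"
	"math/rand"
	"net"
)

// freePort asks the kernel for an ephemeral port; on failure it falls back to
// a random port in the IANA dynamic/private range, as NewClient does.
func freePort() int {
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		if l, err := net.ListenTCP("tcp", a); err == nil {
			p := l.Addr().(*net.TCPAddr).Port
			l.Close()
			return p
		}
	}
	return rand.Intn(65535-49152) + 49152
}

func main() {
	fmt.Println(freePort())
}
```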
|
||||
|
||||
func (c *Client) CheckError(w *http.Response) error {
|
||||
if w.StatusCode >= 400 {
|
||||
return errors.New(w.Status)
|
||||
func (c *Client) getLastErr() string {
|
||||
c.lastErrLock.Lock()
|
||||
defer c.lastErrLock.Unlock()
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
func (c *Client) waitUntilRunning() error {
|
||||
ctx := context.Background()
|
||||
timeout := time.After(2 * time.Minute)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-c.done:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
||||
}
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
||||
case <-timeout:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
||||
}
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
if err := c.Ping(ctx); err == nil {
|
||||
slog.Info("mlx runner is ready", "port", c.port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Options *completionOpts `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type completionOpts struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
MinP float32 `json:"min_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
|
||||
c.cmd.Process.Signal(os.Interrupt)
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
c.cmd.Process.Kill()
|
||||
}
|
||||
c.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements llm.LlamaServer.
|
||||
func (c *Client) Close() error {
|
||||
return c.Cmd.Process.Kill()
|
||||
}
|
||||
|
||||
// Completion implements llm.LlamaServer.
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(req); err != nil {
|
||||
return err
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
Temperature: req.Options.Temperature,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TopK: req.Options.TopK,
|
||||
NumPredict: req.Options.NumPredict,
|
||||
}
|
||||
}
|
||||
|
||||
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
if err := c.CheckError(w); err != nil {
|
||||
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
scanner := bufio.NewScanner(w.Body)
|
||||
for scanner.Scan() {
|
||||
bts := scanner.Bytes()
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var resp llm.CompletionResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(resp)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
|
||||
return nil
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
if cresp.Done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
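Completion reads the runner's response as line-delimited JSON, one object per line carrying a content fragment and a done flag; the field names come from the anonymous struct in the hunk, while the wire format itself is otherwise assumed. A toy decoder:

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// chunk mirrors the content/done fields the client decodes per line.
type chunk struct {
	Content string `json:"content,omitempty"`
	Done    bool   `json:"done"`
}

func main() {
	stream := "{\"content\":\"Hel\"}\n{\"content\":\"lo\"}\n{\"done\":true}\n"
	sc := bufio.NewScanner(strings.NewReader(stream))
	var out strings.Builder
	for sc.Scan() {
		var c chunk
		if err := json.Unmarshal(sc.Bytes(), &c); err != nil {
			continue // skip malformed lines, as the client does
		}
		out.WriteString(c.Content)
		if c.Done {
			break
		}
	}
	fmt.Println(out.String()) // Hello
}
```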
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
@@ -80,71 +302,89 @@ func (c *Client) ContextLength() int {
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
panic("unimplemented")
|
||||
return "", errors.New("not supported")
|
||||
}
|
||||
|
||||
// Embedding implements llm.LlamaServer.
|
||||
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
panic("unimplemented")
|
||||
return nil, 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// GetDeviceInfos implements llm.LlamaServer.
|
||||
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
panic("unimplemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPort implements llm.LlamaServer.
|
||||
func (c *Client) GetPort() int {
|
||||
return c.Port
|
||||
return c.port
|
||||
}
|
||||
|
||||
// HasExited implements llm.LlamaServer.
|
||||
func (c *Client) HasExited() bool {
|
||||
panic("unimplemented")
|
||||
select {
|
||||
case <-c.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Load implements llm.LlamaServer.
|
||||
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
return []ml.DeviceID{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ModelPath implements llm.LlamaServer.
|
||||
func (c *Client) ModelPath() string {
|
||||
panic("unimplemented")
|
||||
return c.modelName
|
||||
}
|
||||
|
||||
// Pid implements llm.LlamaServer.
|
||||
func (c *Client) Pid() int {
|
||||
panic("unimplemented")
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
return c.cmd.Process.Pid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
w, err := http.Get(c.JoinPath("/v1/status"))
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tokenize implements llm.LlamaServer.
|
||||
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokens []int
|
||||
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -153,22 +393,22 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
|
||||
func (c *Client) TotalSize() uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
|
||||
func (c *Client) VRAMSize() uint64 {
|
||||
panic("unimplemented")
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
|
||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
panic("unimplemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ llm.LlamaServer = (*Client)(nil)
|
||||
|
||||
7	x/mlxrunner/imports.go (new file)
@@ -0,0 +1,7 @@
//go:build mlx

package mlxrunner

import (
    _ "github.com/ollama/ollama/x/models/glm4_moe_lite"
)
@@ -133,6 +133,7 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
}

func (t *Array) Set(other *Array) {
    Free(t.desc.inputs...)
    other.desc.numRefs++
    t.desc.inputs = []*Array{other}
    C.mlx_array_set(&t.ctx, other.ctx)
@@ -248,9 +249,9 @@ func Free(s ...*Array) (n int) {
    free := make([]*Array, 0, 8192)
    fn := func(t *Array) {
        if t.Valid() {
            free = append(free, t.desc.inputs...)
            t.desc.numRefs--
            if t.desc.numRefs <= 0 {
                free = append(free, t.desc.inputs...)
                logutil.Trace("Free", "t", t)
                n += t.NumBytes()
                C.mlx_array_free(t.ctx)
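The Free hunk moves the `free = append(free, t.desc.inputs...)` line inside the refcount check, so a node's inputs are only queued for freeing once the node itself actually drops to zero references. A simplified sketch of the corrected behaviour (the `node` type and recursive `free` are illustrative, not the real Array internals):

```go
package main

import "fmt"

type node struct {
	name    string
	numRefs int
	inputs  []*node
}

// free decrements a node and only releases its inputs when it is truly dead.
func free(n *node, freed *[]string) {
	n.numRefs--
	if n.numRefs <= 0 {
		*freed = append(*freed, n.name)
		for _, in := range n.inputs {
			free(in, freed)
		}
	}
}

func main() {
	a := &node{name: "a", numRefs: 2}
	b := &node{name: "b", numRefs: 1, inputs: []*node{a}}
	var freed []string
	free(b, &freed) // frees b, decrements a to 1 (a stays alive)
	fmt.Println(freed, a.numRefs) // [b] 1
}
```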
@@ -24,6 +24,37 @@ func CheckInit() error {
    return initError
}

// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
// Returns true if the library was successfully loaded.
func tryLoadFromDir(dir string) bool {
    matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
    if err != nil || len(matches) == 0 {
        return false
    }

    for _, match := range matches {
        path := filepath.Join(dir, match)

        cPath := C.CString(path)
        defer C.free(unsafe.Pointer(cPath))

        var handle C.mlx_dynamic_handle
        if C.mlx_dynamic_load(&handle, cPath) != 0 {
            slog.Error("Failed to load MLX dynamic library", "path", path)
            continue
        }

        if C.mlx_dynamic_load_symbols(handle) != 0 {
            slog.Error("Failed to load MLX dynamic library symbols", "path", path)
            C.mlx_dynamic_unload(&handle)
            continue
        }

        return true
    }
    return false
}

func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -33,44 +64,34 @@ func init() {
|
||||
return
|
||||
}
|
||||
|
||||
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
|
||||
if !ok {
|
||||
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
|
||||
return
|
||||
// Try OLLAMA_LIBRARY_PATH first
|
||||
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||
for _, dir := range filepath.SplitList(paths) {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, path := range filepath.SplitList(paths) {
|
||||
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
|
||||
if err != nil {
|
||||
initError = fmt.Errorf("failed to glob for MLX libraries in %s: %w", path, err)
|
||||
slog.Warn("MLX dynamic library not available", "error", initError)
|
||||
return
|
||||
// Build search paths: executable directory, then build directories
|
||||
var searchDirs []string
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
searchDirs = append(searchDirs, filepath.Dir(exe))
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(paths, match)
|
||||
slog.Info("Loading MLX dynamic library", "path", path)
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
|
||||
}
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library", "path", path)
|
||||
continue
|
||||
}
|
||||
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Loaded MLX dynamic library", "path", path)
|
||||
for _, dir := range searchDirs {
|
||||
if tryLoadFromDir(dir) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
initError = fmt.Errorf("failed to load any MLX dynamic library from OLLAMA_LIBRARY_PATH=%s", paths)
|
||||
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
|
||||
slog.Warn("MLX dynamic library not available", "error", initError)
|
||||
}
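The rewritten init() widens the search for libmlxc.*: every OLLAMA_LIBRARY_PATH entry first, then the (symlink-resolved) executable's directory, then build/lib/ollama under the working directory. A minimal sketch that reproduces just that directory ordering (`searchDirs` is an illustrative name):

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// searchDirs lists candidate directories in the order init() tries them.
func searchDirs() []string {
	var dirs []string
	if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
		dirs = append(dirs, filepath.SplitList(paths)...)
	}
	if exe, err := os.Executable(); err == nil {
		if eval, err := filepath.EvalSymlinks(exe); err == nil {
			exe = eval
		}
		dirs = append(dirs, filepath.Dir(exe))
	}
	if cwd, err := os.Getwd(); err == nil {
		dirs = append(dirs, filepath.Join(cwd, "build", "lib", "ollama"))
	}
	return dirs
}

func main() {
	fmt.Println(searchDirs())
}
```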
|
||||
|
||||
@@ -306,19 +306,42 @@ func AddMM(c, a, b *Array, alpha, beta float32) *Array {

// Scalar helpers

// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
    f32 := C.mlx_array_new_float(C.float(s))
    dtype := a.DType()
    if dtype == DTypeFloat32 {
        return f32
    }
    casted := C.mlx_array_new()
    C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
    C.mlx_array_free(f32)
    return casted
}

func AddScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Add(scalar)
    scalar := scalarWithDtype(s, a)
    out := New("ADD_SCALAR", a)
    C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
    C.mlx_array_free(scalar)
    return out
}

func MulScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Multiply(scalar)
    scalar := scalarWithDtype(s, a)
    out := New("MUL_SCALAR", a)
    C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
    C.mlx_array_free(scalar)
    return out
}

func DivScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Divide(scalar)
    scalar := scalarWithDtype(s, a)
    out := New("DIV_SCALAR", a)
    C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
    C.mlx_array_free(scalar)
    return out
}

func FloorDivideScalar(a *Array, s int32) *Array {
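scalarWithDtype exists so that the scalar, not the tensor, gets converted. A plain-Go illustration of the rationale stated in the comment ("avoiding implicit casts"); the promote rule below is a stand-in for MLX's actual type-promotion semantics, which are assumed rather than taken from the diff:

```go
package main

import "fmt"

type dtype string

// promote models a binary op widening to float32 if either operand is float32.
func promote(a, b dtype) dtype {
	if a == "float32" || b == "float32" {
		return "float32"
	}
	return a
}

func main() {
	tensorDtype := dtype("float16")
	scalarDtype := dtype("float32")

	// Without the cast: the op would run at the promoted (wider) type.
	fmt.Println("op dtype without cast:", promote(tensorDtype, scalarDtype)) // float32

	// With scalarWithDtype: the scalar is converted first (O(1)), the op stays in float16.
	scalarDtype = tensorDtype
	fmt.Println("op dtype with cast:   ", promote(tensorDtype, scalarDtype)) // float16
}
```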
85	x/mlxrunner/model/base/base.go (new file)
@@ -0,0 +1,85 @@
//go:build mlx

package base

import (
    "encoding/json"
    "fmt"
    "log/slog"
    "sync"

    "github.com/ollama/ollama/x/imagegen/tokenizer"
    "github.com/ollama/ollama/x/mlxrunner/cache"
    "github.com/ollama/ollama/x/mlxrunner/mlx"
    "github.com/ollama/ollama/x/mlxrunner/model"
)

// Model is the interface that model implementations must satisfy.
type Model interface {
    Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
    Unembed(x *mlx.Array) *mlx.Array
    NumLayers() int
    Tokenizer() *tokenizer.Tokenizer

    // LoadWeights receives all tensors loaded from the manifest and assigns
    // them to model fields. Model-specific logic (MLA absorption, expert
    // stacking, quantized layer creation) happens here.
    LoadWeights(tensors map[string]*mlx.Array) error
}

var (
    mu       sync.Mutex
    registry = make(map[string]func(root *model.Root) (Model, error))
)

// Register registers a model constructor by architecture name.
// Called from init() in model packages. Panics on duplicate registration.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
    mu.Lock()
    defer mu.Unlock()

    if _, exists := registry[arch]; exists {
        panic(fmt.Sprintf("model architecture %q already registered", arch))
    }
    registry[arch] = fn
}

// New reads config.json from the manifest, detects the architecture, looks up
// the registered constructor, and calls it to create the model (with config
// parsed and struct created, but weights not yet loaded).
func New(root *model.Root) (Model, error) {
    configData, err := root.Manifest.ReadConfig("config.json")
    if err != nil {
        return nil, fmt.Errorf("failed to read config.json: %w", err)
    }

    var archConfig struct {
        Architectures []string `json:"architectures"`
    }
    if err := json.Unmarshal(configData, &archConfig); err != nil {
        return nil, fmt.Errorf("failed to parse config.json: %w", err)
    }

    if len(archConfig.Architectures) == 0 {
        return nil, fmt.Errorf("no architectures found in config.json")
    }

    arch := archConfig.Architectures[0]
    slog.Info("Model architecture", "arch", arch)

    mu.Lock()
    fn, ok := registry[arch]
    mu.Unlock()

    if !ok {
        return nil, fmt.Errorf("unsupported architecture: %s", arch)
    }

    return fn(root)
}

// Weights returns the model's LoadWeights method, which encapsulates all
// weight assignment and post-processing (MLA absorption, expert stacking).
func Weights(m Model) func(map[string]*mlx.Array) error {
    return m.LoadWeights
}
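base.Register and base.New implement a mutex-guarded registry keyed by the architectures field of config.json, populated from model-package init() functions pulled in via blank imports (see x/mlxrunner/imports.go above). A reduced, runnable sketch of that registry, with string-valued constructors standing in for `func(*model.Root) (Model, error)`:

```go
package main

import (
	"fmt"
	"sync"
)

var (
	mu       sync.Mutex
	registry = map[string]func() string{}
)

// register panics on duplicates, mirroring base.Register.
func register(arch string, fn func() string) {
	mu.Lock()
	defer mu.Unlock()
	if _, exists := registry[arch]; exists {
		panic(fmt.Sprintf("model architecture %q already registered", arch))
	}
	registry[arch] = fn
}

// newModelFor looks up the constructor, mirroring base.New after config parsing.
func newModelFor(arch string) (string, error) {
	mu.Lock()
	fn, ok := registry[arch]
	mu.Unlock()
	if !ok {
		return "", fmt.Errorf("unsupported architecture: %s", arch)
	}
	return fn(), nil
}

func main() {
	// A model package would do this from init(), pulled in via a blank import.
	register("Glm4MoeLiteForCausalLM", func() string { return "glm4_moe_lite" })

	m, err := newModelFor("Glm4MoeLiteForCausalLM")
	fmt.Println(m, err)
}
```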
3	x/mlxrunner/model/base/base_stub.go (new file)
@@ -0,0 +1,3 @@
//go:build !mlx

package base
97
x/mlxrunner/model/root.go
Normal file
97
x/mlxrunner/model/root.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build mlx
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
quantType string
|
||||
groupSize int
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and pre-scans the first
|
||||
// tensor blob for quantization metadata (quant_type, group_size).
|
||||
func Open(modelName string) (*Root, error) {
|
||||
m, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := &Root{Manifest: m}
|
||||
|
||||
// Pre-scan first tensor blob for quantization metadata
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
meta, err := readBlobMetadata(blobPath)
|
||||
if err != nil || meta == nil {
|
||||
continue
|
||||
}
|
||||
if qt := meta["quant_type"]; qt != "" {
|
||||
root.quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
fmt.Sscanf(gs, "%d", &root.groupSize)
|
||||
}
|
||||
break // only check the first tensor blob
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// Close is a no-op for now (future: release resources).
|
||||
func (r *Root) Close() {}
|
||||
|
||||
// QuantType returns the quantization type detected from tensor metadata.
|
||||
func (r *Root) QuantType() string { return r.quantType }
|
||||
|
||||
// GroupSize returns the quantization group size detected from tensor metadata.
|
||||
func (r *Root) GroupSize() int { return r.groupSize }
|
||||
|
||||
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
|
||||
func readBlobMetadata(path string) (map[string]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
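readBlobMetadata in the new root.go assumes the safetensors layout: an 8-byte little-endian header length, a JSON header, and an optional __metadata__ object with string values such as quant_type and group_size. A round-trip sketch of that layout (the metadata values shown are placeholders, not real quantization names):

```go
package main

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
)

func main() {
	header := map[string]any{
		"__metadata__": map[string]string{"quant_type": "q4", "group_size": "32"}, // placeholder values
	}
	js, _ := json.Marshal(header)

	var blob bytes.Buffer
	binary.Write(&blob, binary.LittleEndian, uint64(len(js))) // 8-byte header size prefix
	blob.Write(js)                                            // JSON header (tensor data would follow)

	// Read it back the same way the loader does.
	r := bytes.NewReader(blob.Bytes())
	var size uint64
	binary.Read(r, binary.LittleEndian, &size)
	buf := make([]byte, size)
	io.ReadFull(r, buf)

	var parsed struct {
		Metadata map[string]string `json:"__metadata__"`
	}
	json.Unmarshal(buf, &parsed)
	fmt.Println(parsed.Metadata["quant_type"], parsed.Metadata["group_size"])
}
```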
3	x/mlxrunner/model/root_stub.go (new file)
@@ -0,0 +1,3 @@
//go:build !mlx

package model
@@ -18,6 +18,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
        return errors.New("model not loaded")
    }

    mlx.EnableCompile()

    inputs := r.Tokenizer.Encode(request.Prompt, true)

    caches, tokens := r.FindNearestCache(inputs)
@@ -47,7 +49,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
    }

    step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
        logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches))
        fwd := r.Model.Forward(token.ExpandDims(0), caches)
        logits := r.Model.Unembed(fwd)
        logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

        logprobs := logits.Subtract(logits.Logsumexp(true))
@@ -60,7 +63,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
    var b bytes.Buffer

    now := time.Now()
    final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
    final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
    outputs := make([]int32, 0, request.Options.MaxTokens)
    for i := range request.Options.MaxTokens {
        nextSample, nextLogprobs := step(sample)
@@ -4,30 +4,22 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
)
|
||||
|
||||
// TextModel is the interface that model implementations must satisfy.
|
||||
type TextModel interface {
|
||||
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan Response
|
||||
@@ -66,52 +58,95 @@ type Response struct {
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model TextModel
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
CacheEntries map[int32]*CacheEntry
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
root, err := model.Open(modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
m, err := base.New(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read config to detect architecture
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
// Load all tensor blobs from manifest
|
||||
tensors, err := loadTensorsFromManifest(root)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config.json: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
if len(archConfig.Architectures) == 0 {
|
||||
return fmt.Errorf("no architectures found in config.json")
|
||||
}
|
||||
|
||||
slog.Info("Model architecture", "arch", archConfig.Architectures[0])
|
||||
|
||||
switch archConfig.Architectures[0] {
|
||||
case "Glm4MoeLiteForCausalLM", "GLM4MoeLite":
|
||||
model, err := glm4_moe_lite.LoadFromManifest(modelManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err)
|
||||
}
|
||||
r.Model = model
|
||||
r.Tokenizer = model.Tokenizer()
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0])
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
|
||||
// flat map, deduplicating by digest and remapping safetensors key suffixes.
|
||||
//
|
||||
// Uses a two-phase approach: first loads all raw tensors, then remaps
|
||||
// .bias → _qbias with complete knowledge of which base names have .scale
|
||||
// entries. This avoids a race condition where Go map iteration order could
|
||||
// cause .bias to be processed before .scale within the same blob.
|
||||
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
|
||||
// Phase 1: Load all tensors raw from all blobs
|
||||
rawTensors := make(map[string]*mlx.Array)
|
||||
seen := make(map[string]bool)
|
||||
for _, layer := range root.Manifest.GetTensorLayers("") {
|
||||
if seen[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
seen[layer.Digest] = true
|
||||
blobPath := root.Manifest.BlobPath(layer.Digest)
|
||||
for name, arr := range mlx.Load(blobPath) {
|
||||
rawTensors[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Identify all base names that have .scale tensors and remap them
|
||||
scaleBaseNames := make(map[string]bool)
|
||||
allTensors := make(map[string]*mlx.Array, len(rawTensors))
|
||||
for name, arr := range rawTensors {
|
||||
if strings.HasSuffix(name, ".scale") {
|
||||
baseName := strings.TrimSuffix(name, ".scale")
|
||||
allTensors[baseName+"_scale"] = arr
|
||||
scaleBaseNames[baseName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Process remaining tensors with complete scale knowledge
|
||||
for name, arr := range rawTensors {
|
||||
if strings.HasSuffix(name, ".scale") {
|
||||
continue // already handled
|
||||
}
|
||||
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
|
||||
baseName := strings.TrimSuffix(name, ".bias")
|
||||
if scaleBaseNames[baseName] {
|
||||
allTensors[baseName+"_qbias"] = arr
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
|
||||
return allTensors, nil
|
||||
}
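The two-phase remap in loadTensorsFromManifest avoids depending on map iteration order when deciding whether a .bias tensor belongs to a quantized layer. A stand-alone sketch with strings in place of *mlx.Array (the .weight_qbias exclusion from the real code is omitted for brevity):

```go
package main

import (
	"fmt"
	"strings"
)

// remap renames ".scale" to "_scale", and ".bias" to "_qbias" only when the
// same base name also has a scale tensor; everything else is kept as-is.
func remap(raw map[string]string) map[string]string {
	out := make(map[string]string, len(raw))
	scaleBase := map[string]bool{}

	for name, v := range raw { // phase 1: record scales first
		if strings.HasSuffix(name, ".scale") {
			base := strings.TrimSuffix(name, ".scale")
			out[base+"_scale"] = v
			scaleBase[base] = true
		}
	}
	for name, v := range raw { // phase 2: biases and plain tensors
		if strings.HasSuffix(name, ".scale") {
			continue
		}
		if strings.HasSuffix(name, ".bias") && scaleBase[strings.TrimSuffix(name, ".bias")] {
			out[strings.TrimSuffix(name, ".bias")+"_qbias"] = v
			continue
		}
		out[name] = v
	}
	return out
}

func main() {
	raw := map[string]string{
		"lm_head.weight.scale": "s", // quantized layer: has a scale
		"lm_head.weight.bias":  "qb",
		"mlp.gate.bias":        "b", // plain bias, no matching scale
	}
	fmt.Println(remap(raw))
}
```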
|
||||
|
||||
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func (c chain) Sample(logits *mlx.Array) *mlx.Array {
type Temperature float32

func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
    return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
    return mlx.DivScalar(logits, float32(t)).Categorical(-1)
}

type TopP float32

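The sampler change swaps an explicit multiply-by-1/t for mlx.DivScalar, but the math is unchanged: logits are divided by the temperature before categorical sampling. A plain-Go illustration of the effect on the resulting distribution (values arbitrary; the max-subtraction stability trick is omitted):

```go
package main

import (
	"fmt"
	"math"
)

// softmax applies temperature scaling (logits / t) before normalizing.
func softmax(logits []float64, t float64) []float64 {
	probs := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp(l / t)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

func main() {
	logits := []float64{2.0, 1.0, 0.1}
	fmt.Println(softmax(logits, 0.5)) // sharper: mass concentrates on the top logit
	fmt.Println(softmax(logits, 1.5)) // flatter: closer to uniform
}
```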
@@ -5,21 +5,24 @@
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Glm4MoeLiteForCausalLM", newModel)
|
||||
base.Register("GLM4MoeLite", newModel)
|
||||
}
|
||||
|
||||
// RopeScaling holds RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
@@ -131,7 +134,6 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
|
||||
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
out = a.UnembedOut.Forward(out)
|
||||
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
@@ -386,44 +388,6 @@ func quantizationParams(quantization string) (groupSize, bits int, mode string)
|
||||
}
|
||||
}
|
||||
|
||||
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
|
||||
func readBlobMetadata(path string) (map[string]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if headerSize > 1024*1024 {
|
||||
return nil, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||
type ExpertWeight struct {
|
||||
Weight *mlx.Array
|
||||
@@ -569,9 +533,10 @@ func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.Line
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
configData, err := modelManifest.ReadConfig("config.json")
|
||||
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
||||
// no weights loaded yet). Called by the registry via base.New().
|
||||
func newModel(root *model.Root) (base.Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
@@ -584,66 +549,18 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = computeScale(&cfg)
|
||||
|
||||
// Load all tensors from manifest blobs into a flat map
|
||||
allTensors := make(map[string]*mlx.Array)
|
||||
seen := make(map[string]bool) // dedupe by digest
|
||||
var quantType string
|
||||
var quantGroupSize int
|
||||
|
||||
for _, layer := range modelManifest.GetTensorLayers("") {
|
||||
if seen[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
seen[layer.Digest] = true
|
||||
blobPath := modelManifest.BlobPath(layer.Digest)
|
||||
|
||||
// Read quantization metadata from first blob
|
||||
if quantType == "" {
|
||||
if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil {
|
||||
if qt := meta["quant_type"]; qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
fmt.Sscanf(gs, "%d", &quantGroupSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, arr := range mlx.Load(blobPath) {
|
||||
// Map safetensors key naming to our naming convention
|
||||
// Combined blobs use ".scale" and ".bias" suffixes
|
||||
if strings.HasSuffix(name, ".scale") {
|
||||
baseName := strings.TrimSuffix(name, ".scale")
|
||||
allTensors[baseName+"_scale"] = arr
|
||||
} else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
|
||||
// Check if this is a quantization bias or a regular bias
|
||||
// by checking if there's a corresponding weight
|
||||
baseName := strings.TrimSuffix(name, ".bias")
|
||||
if _, hasScale := allTensors[baseName+"_scale"]; hasScale {
|
||||
allTensors[baseName+"_qbias"] = arr
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set up quantization parameters
|
||||
useQuantized := false
|
||||
if quantType != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(quantType)
|
||||
if quantGroupSize > 0 {
|
||||
cfg.QuantGroupSize = quantGroupSize
|
||||
// Set up quantization parameters from pre-scanned metadata
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
|
||||
if gs := root.GroupSize(); gs > 0 {
|
||||
cfg.QuantGroupSize = gs
|
||||
} else {
|
||||
cfg.QuantGroupSize, _, _ = quantizationParams(quantType)
|
||||
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
|
||||
}
|
||||
useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
}
|
||||
|
||||
// Load tokenizer
|
||||
tokData, err := modelManifest.ReadConfig("tokenizer.json")
|
||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
@@ -652,11 +569,11 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
ConfigJSON: configData,
|
||||
}
|
||||
|
||||
if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
|
||||
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
@@ -671,18 +588,28 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||
// to model fields. Handles MLA absorption, expert stacking, and quantized
|
||||
// layer creation.
|
||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
cfg := m.Config
|
||||
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||
|
||||
// Load embedding
|
||||
if w := allTensors["model.embed_tokens.weight"]; w != nil {
|
||||
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
||||
m.EmbedTokens = nn.NewEmbedding(w)
|
||||
}
|
||||
|
||||
// Load final norm
|
||||
if w := allTensors["model.norm.weight"]; w != nil {
|
||||
if w := tensors["model.norm.weight"]; w != nil {
|
||||
m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
|
||||
// Load LM head
|
||||
m.LMHead = makeLinear(allTensors, "lm_head", &cfg)
|
||||
m.LMHead = makeLinear(tensors, "lm_head", cfg)
|
||||
|
||||
// Load layers
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
@@ -690,24 +617,24 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg)
|
||||
if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
||||
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
|
||||
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
||||
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg)
|
||||
attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg)
|
||||
if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
||||
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
|
||||
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
|
||||
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
||||
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||
}
|
||||
attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg)
|
||||
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
|
||||
|
||||
// Sanitize MLA weights for absorbed attention
|
||||
embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg)
|
||||
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
||||
attn.EmbedQ = nn.NewMultiLinear(embedQ)
|
||||
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
|
||||
|
||||
inputLN := allTensors[prefix+".input_layernorm.weight"]
|
||||
postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"]
|
||||
inputLN := tensors[prefix+".input_layernorm.weight"]
|
||||
postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
@@ -720,9 +647,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
block.MLP = &DenseMLP{
|
||||
GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg),
|
||||
UpProj: makeLinear(allTensors, prefix+".mlp.up_proj", &cfg),
|
||||
DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg),
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
@@ -737,7 +664,7 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
// Stack expert weights
|
||||
gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
|
||||
gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
|
||||
|
||||
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
|
||||
if useQuantized {
|
||||
@@ -763,8 +690,8 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
}
|
||||
|
||||
moeGate := &MoEGate{}
|
||||
moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg)
|
||||
if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
||||
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
|
||||
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
||||
moeGate.EScoreCorrectionBias = bias
|
||||
}
|
||||
|
||||
@@ -776,9 +703,9 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{
|
||||
GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg),
|
||||
UpProj: makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg),
|
||||
DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg),
|
||||
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
|
||||
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
|
||||
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -786,9 +713,10 @@ func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return m, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
|
||||