mirror of
https://github.com/ollama/ollama.git
synced 2026-02-20 08:16:07 -05:00
* load glm4_moe_lite from the mlxrunner * fix loading diffusion models * remove log lines * fix --imagegen flag
86 lines
2.4 KiB
Go
86 lines
2.4 KiB
Go
//go:build mlx
|
|
|
|
package base
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
)
|
|
|
|
// Model is the interface that model implementations must satisfy.
|
|
type Model interface {
|
|
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
|
|
Unembed(x *mlx.Array) *mlx.Array
|
|
NumLayers() int
|
|
Tokenizer() *tokenizer.Tokenizer
|
|
|
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
|
// stacking, quantized layer creation) happens here.
|
|
LoadWeights(tensors map[string]*mlx.Array) error
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
registry = make(map[string]func(root *model.Root) (Model, error))
|
|
)
|
|
|
|
// Register registers a model constructor by architecture name.
|
|
// Called from init() in model packages. Panics on duplicate registration.
|
|
func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if _, exists := registry[arch]; exists {
|
|
panic(fmt.Sprintf("model architecture %q already registered", arch))
|
|
}
|
|
registry[arch] = fn
|
|
}
|
|
|
|
// New reads config.json from the manifest, detects the architecture, looks up
|
|
// the registered constructor, and calls it to create the model (with config
|
|
// parsed and struct created, but weights not yet loaded).
|
|
func New(root *model.Root) (Model, error) {
|
|
configData, err := root.Manifest.ReadConfig("config.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
|
}
|
|
|
|
var archConfig struct {
|
|
Architectures []string `json:"architectures"`
|
|
}
|
|
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to parse config.json: %w", err)
|
|
}
|
|
|
|
if len(archConfig.Architectures) == 0 {
|
|
return nil, fmt.Errorf("no architectures found in config.json")
|
|
}
|
|
|
|
arch := archConfig.Architectures[0]
|
|
slog.Info("Model architecture", "arch", arch)
|
|
|
|
mu.Lock()
|
|
fn, ok := registry[arch]
|
|
mu.Unlock()
|
|
|
|
if !ok {
|
|
return nil, fmt.Errorf("unsupported architecture: %s", arch)
|
|
}
|
|
|
|
return fn(root)
|
|
}
|
|
|
|
// Weights returns the model's LoadWeights method, which encapsulates all
|
|
// weight assignment and post-processing (MLA absorption, expert stacking).
|
|
func Weights(m Model) func(map[string]*mlx.Array) error {
|
|
return m.LoadWeights
|
|
}
|