Mirror of https://github.com/ollama/ollama.git
Synced 2026-01-24 07:20:57 -05:00

Compare commits: v0.15.0...pdevine/gl (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 49905784f1 | |
| | a00721f586 | |
| | 98ca1c3904 | |
@@ -3,7 +3,7 @@ package runner
 import (
 	"github.com/ollama/ollama/runner/llamarunner"
 	"github.com/ollama/ollama/runner/ollamarunner"
-	imagerunner "github.com/ollama/ollama/x/imagegen/runner"
+	"github.com/ollama/ollama/x/mlxrunner"
 )
 
 func Execute(args []string) error {
@@ -12,18 +12,18 @@ func Execute(args []string) error {
 	}
 
 	var newRunner bool
-	var imageRunner bool
+	var mlxRunner bool
 	if len(args) > 0 && args[0] == "--ollama-engine" {
 		args = args[1:]
 		newRunner = true
 	}
-	if len(args) > 0 && args[0] == "--image-engine" {
+	if len(args) > 0 && args[0] == "--mlx-engine" {
 		args = args[1:]
-		imageRunner = true
+		mlxRunner = true
 	}
 
-	if imageRunner {
-		return imagerunner.Execute(args)
+	if mlxRunner {
+		return mlxrunner.Execute(args)
 	} else if newRunner {
 		return ollamarunner.Execute(args)
 	} else {
@@ -31,6 +31,7 @@ import (
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/x/imagegen/transfer"
+	xserver "github.com/ollama/ollama/x/server"
 )
 
 var (
@@ -129,6 +130,14 @@ func (m *Model) Capabilities() []model.Capability {
 		return capabilities
 	}
 
+	// Check for thinking capability in safetensors LLM models based on architecture
+	if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
+		if xserver.IsSafetensorsThinkingModel(model.ParseName(m.Name)) {
+			capabilities = append(capabilities, model.CapabilityThinking)
+			return capabilities
+		}
+	}
+
 	// Check for thinking capability
 	openingTag, closingTag := thinking.InferTags(m.Template.Template)
 	hasTags := openingTag != "" && closingTag != ""
@@ -21,7 +21,7 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/types/model"
-	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/mlxrunner"
 )
 
 type LlmRequest struct {
@@ -195,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
 				}
 
-				// Check for image generation model before attempting GGML load
+				// Check for image generation models - all use MLX runner
 				if slices.Contains(pending.model.Config.Capabilities, "image") {
-					if s.loadImageGen(pending) {
+					if s.loadMLX(pending) {
 						break
 					}
 					continue
 				}
 
+				// Check for experimental safetensors LLM models
+				if pending.model.Config.ModelFormat == "safetensors" {
+					if slices.Contains(pending.model.Config.Capabilities, "completion") {
+						// LLM model with safetensors format - use MLX runner
+						if s.loadMLX(pending) {
+							break
+						}
+						continue
+					}
+				}
+
 				// Load model for fitting
 				logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
 				ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -552,11 +563,20 @@ iGPUScan:
 	return false
 }
 
-// loadImageGen loads an image generation model.
-func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
-	// Use model name for imagegen (it resolves manifests by name, not file path)
+// loadMLX loads an experimental safetensors model using the unified MLX runner.
+// This supports both LLM (completion) and image generation models.
+func (s *Scheduler) loadMLX(req *LlmRequest) bool {
+	// Determine mode based on capabilities
+	var mode mlxrunner.ModelMode
+	if slices.Contains(req.model.Config.Capabilities, "image") {
+		mode = mlxrunner.ModeImageGen
+	} else {
+		mode = mlxrunner.ModeLLM
+	}
+
+	// Use model name for MLX (it resolves manifests by name, not file path)
 	modelName := req.model.ShortName
-	server, err := imagegen.NewServer(modelName)
+	server, err := mlxrunner.NewServer(modelName, mode)
 	if err != nil {
 		req.errCh <- err
 		return true
@@ -13,6 +13,7 @@ import (
 	"io"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"github.com/ollama/ollama/manifest"
 	"github.com/ollama/ollama/progress"
@@ -53,10 +54,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
 	// Determine model type settings
 	var modelType, spinnerKey string
 	var capabilities []string
+	var parserName, rendererName string
 	if isSafetensors {
 		modelType = "safetensors model"
 		spinnerKey = "create"
 		capabilities = []string{"completion"}
+
+		// Check if model supports thinking based on architecture
+		if supportsThinking(opts.ModelDir) {
+			capabilities = append(capabilities, "thinking")
+		}
+
+		// Set parser and renderer name based on architecture
+		parserName = getParserName(opts.ModelDir)
+		rendererName = getRendererName(opts.ModelDir)
 	} else {
 		modelType = "image generation model"
 		spinnerKey = "imagegen"
@@ -81,14 +92,14 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
 		err = create.CreateSafetensorsModel(
 			opts.ModelName, opts.ModelDir, opts.Quantize,
 			newLayerCreator(), newTensorLayerCreator(),
-			newManifestWriter(opts, capabilities),
+			newManifestWriter(opts, capabilities, parserName, rendererName),
 			progressFn,
 		)
 	} else {
 		err = create.CreateImageGenModel(
 			opts.ModelName, opts.ModelDir, opts.Quantize,
 			newLayerCreator(), newTensorLayerCreator(),
-			newManifestWriter(opts, capabilities),
+			newManifestWriter(opts, capabilities, "", ""),
 			progressFn,
 		)
 	}
@@ -204,7 +215,7 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
 }
 
 // newManifestWriter returns a ManifestWriter callback for writing the model manifest.
-func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
+func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
 	return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
 		name := model.ParseName(modelName)
 		if !name.IsValid() {
@@ -229,6 +240,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
 			ModelFormat:  "safetensors",
 			Capabilities: caps,
 			Requires:     MinOllamaVersion,
+			Parser:       parserName,
+			Renderer:     rendererName,
 		}
 		configJSON, err := json.Marshal(configData)
 		if err != nil {
@@ -295,3 +308,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
 
 	return layers, nil
 }
+
+// supportsThinking checks if the model supports thinking mode based on its architecture.
+// This reads the config.json from the model directory and checks the architectures field.
+func supportsThinking(modelDir string) bool {
+	configPath := filepath.Join(modelDir, "config.json")
+	data, err := os.ReadFile(configPath)
+	if err != nil {
+		return false
+	}
+
+	var cfg struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if err := json.Unmarshal(data, &cfg); err != nil {
+		return false
+	}
+
+	// Check architectures that support thinking
+	thinkingArchitectures := []string{
+		"glm4moe",  // GLM-4 MoE models
+		"deepseek", // DeepSeek models
+		"qwen3",    // Qwen3 models
+	}
+
+	// Check the architecture list
+	for _, arch := range cfg.Architectures {
+		archLower := strings.ToLower(arch)
+		for _, thinkArch := range thinkingArchitectures {
+			if strings.Contains(archLower, thinkArch) {
+				return true
+			}
+		}
+	}
+
+	// Also check model_type
+	if cfg.ModelType != "" {
+		typeLower := strings.ToLower(cfg.ModelType)
+		for _, thinkArch := range thinkingArchitectures {
+			if strings.Contains(typeLower, thinkArch) {
+				return true
+			}
+		}
+	}
+
+	return false
+}
+
+// getParserName returns the parser name for a model based on its architecture.
+// This reads the config.json from the model directory and determines the appropriate parser.
+func getParserName(modelDir string) string {
+	configPath := filepath.Join(modelDir, "config.json")
+	data, err := os.ReadFile(configPath)
+	if err != nil {
+		return ""
+	}
+
+	var cfg struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if err := json.Unmarshal(data, &cfg); err != nil {
+		return ""
+	}
+
+	// Check architectures for known parsers
+	for _, arch := range cfg.Architectures {
+		archLower := strings.ToLower(arch)
+		if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
+			return "glm-4.7"
+		}
+		if strings.Contains(archLower, "deepseek") {
+			return "deepseek3"
+		}
+		if strings.Contains(archLower, "qwen3") {
+			return "qwen3-coder"
+		}
+	}
+
+	// Also check model_type
+	if cfg.ModelType != "" {
+		typeLower := strings.ToLower(cfg.ModelType)
+		if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
+			return "glm-4.7"
+		}
+		if strings.Contains(typeLower, "deepseek") {
+			return "deepseek3"
+		}
+		if strings.Contains(typeLower, "qwen3") {
+			return "qwen3-coder"
+		}
+	}
+
+	return ""
+}
+
+// getRendererName returns the renderer name for a model based on its architecture.
+// This reads the config.json from the model directory and determines the appropriate renderer.
+func getRendererName(modelDir string) string {
+	configPath := filepath.Join(modelDir, "config.json")
+	data, err := os.ReadFile(configPath)
+	if err != nil {
+		return ""
+	}
+
+	var cfg struct {
+		Architectures []string `json:"architectures"`
+		ModelType     string   `json:"model_type"`
+	}
+	if err := json.Unmarshal(data, &cfg); err != nil {
+		return ""
+	}
+
+	// Check architectures for known renderers
+	for _, arch := range cfg.Architectures {
+		archLower := strings.ToLower(arch)
+		if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
+			return "glm-4.7"
+		}
+		if strings.Contains(archLower, "deepseek") {
+			return "deepseek3"
+		}
+		if strings.Contains(archLower, "qwen3") {
+			return "qwen3-coder"
+		}
+	}
+
+	// Also check model_type
+	if cfg.ModelType != "" {
+		typeLower := strings.ToLower(cfg.ModelType)
+		if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
+			return "glm-4.7"
+		}
+		if strings.Contains(typeLower, "deepseek") {
+			return "deepseek3"
+		}
+		if strings.Contains(typeLower, "qwen3") {
+			return "qwen3-coder"
+		}
+	}
+
+	return ""
+}
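For illustration, here is a small test-style sketch (not part of the diff) showing what these helpers return for a hypothetical GLM-4 MoE config.json. The file contents and test name are invented for the example, and it assumes the usual `testing`, `os`, and `path/filepath` imports in the same package as the helpers:

```go
func TestArchitectureDetectionSketch(t *testing.T) {
	// Hypothetical config.json resembling a GLM-4 MoE checkpoint.
	dir := t.TempDir()
	cfg := []byte(`{"architectures": ["Glm4MoeLiteForCausalLM"], "model_type": "glm4_moe_lite"}`)
	if err := os.WriteFile(filepath.Join(dir, "config.json"), cfg, 0o644); err != nil {
		t.Fatal(err)
	}

	// "glm4moe" is in thinkingArchitectures, so thinking should be reported.
	if !supportsThinking(dir) {
		t.Error("expected thinking support for glm4_moe architecture")
	}
	// GLM-4 architectures map to the "glm-4.7" parser and renderer.
	if got := getParserName(dir); got != "glm-4.7" {
		t.Errorf("getParserName = %q, want %q", got, "glm-4.7")
	}
	if got := getRendererName(dir); got != "glm-4.7" {
		t.Errorf("getRendererName = %q, want %q", got, "glm-4.7")
	}
}
```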
@@ -13,7 +13,10 @@ import (
 
 // quantizeTensor loads a tensor from safetensors format, quantizes it,
 // and returns safetensors data for the quantized weights, scales, and biases.
-// Supported quantization types: "fp8" (affine 8-bit)
+// Supported quantization types:
+// - "fp4": affine 4-bit, group_size=32 (with qbiases)
+// - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
+// - "fp8": affine 8-bit, group_size=32 (with qbiases)
 // Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
 func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
 	tmpDir := ensureTempDir()
@@ -55,10 +58,13 @@ func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize str
 	var qweight, scales, qbiases *mlx.Array
 	switch quantize {
 	case "fp4":
-		// affine mode: group_size=32, bits=4
+		// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
 		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
+	case "nvfp4":
+		// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
+		qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
 	case "fp8":
-		// affine mode: group_size=32, bits=8
+		// affine mode: group_size=32, bits=8 (with qbiases for zero-point offset)
 		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
 	default:
 		return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
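For reference, the mapping applied by the switch above can be summarized as a small helper. This is a sketch derived from the hunk (it does not exist in the repository) and assumes only the standard `fmt` package:

```go
// quantParams sketches the mapping used by quantizeTensor above:
// quantize type -> (group size, bits, MLX quantization mode).
func quantParams(quantize string) (groupSize, bits int, mode string, err error) {
	switch quantize {
	case "fp4":
		return 32, 4, "affine", nil // affine 4-bit, with qbiases
	case "nvfp4":
		return 16, 4, "nvfp4", nil // NVIDIA FP4, E4M3 scales, no qbiases
	case "fp8":
		return 32, 8, "affine", nil // affine 8-bit, with qbiases
	default:
		return 0, 0, "", fmt.Errorf("unsupported quantization type: %s", quantize)
	}
}
```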
@@ -262,9 +262,10 @@ func ShouldQuantize(name, component string) bool {
 	return strings.HasSuffix(name, ".weight")
 }
 
-// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
+// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
 // This is a more detailed check that also considers tensor dimensions.
-func ShouldQuantizeTensor(name string, shape []int32) bool {
+// The quantize parameter specifies the quantization type (e.g., "fp4", "nvfp4", "fp8").
+func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
 	// Use basic name-based check first
 	if !ShouldQuantize(name, "") {
 		return false
@@ -280,8 +281,13 @@ func ShouldQuantizeTensor(name string, shape []int32) bool {
 		return false
 	}
 
-	// MLX quantization requires last dimension to be divisible by group size (32)
-	if shape[len(shape)-1]%32 != 0 {
+	// MLX quantization requires last dimension to be divisible by group size
+	// NVFP4 uses group_size=16, all other modes use group_size=32
+	groupSize := int32(32)
+	if strings.ToUpper(quantize) == "NVFP4" {
+		groupSize = 16
+	}
+	if shape[len(shape)-1]%groupSize != 0 {
 		return false
 	}
 
@@ -331,7 +337,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 
 		// Determine quantization type for this tensor (empty string if not quantizing)
 		quantizeType := ""
-		if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
+		if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape, quantize) {
 			quantizeType = quantize
 		}
 
@@ -388,6 +394,22 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 		return fmt.Errorf("config.json not found in %s", modelDir)
 	}
 
+	// Create model_index.json with quantization info if quantizing
+	if quantize != "" {
+		modelIndex := map[string]any{
+			"quantization": strings.ToUpper(quantize),
+		}
+		indexData, err := json.MarshalIndent(modelIndex, "", " ")
+		if err != nil {
+			return fmt.Errorf("failed to marshal model_index.json: %w", err)
+		}
+		indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
+		if err != nil {
+			return fmt.Errorf("failed to create model_index.json layer: %w", err)
+		}
+		layers = append(layers, indexLayer)
+	}
+
 	fn(fmt.Sprintf("writing manifest for %s", modelName))
 
 	if err := writeManifest(modelName, configLayer, layers); err != nil {
@@ -536,41 +536,51 @@ func TestShouldQuantize(t *testing.T) {
 
 func TestShouldQuantizeTensor(t *testing.T) {
 	tests := []struct {
-		name   string
-		tensor string
-		shape  []int32
-		want   bool
+		name     string
+		tensor   string
+		shape    []int32
+		quantize string
+		want     bool
 	}{
 		// 2D tensors with sufficient size should be quantized
-		{"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
-		{"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
+		{"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
+		{"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
+		{"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
 
 		// Small tensors should not be quantized (< 1024 elements)
-		{"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
-		{"small 2D weight", "small.weight", []int32{31, 31}, false},
+		{"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
+		{"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
 
 		// 1D tensors should not be quantized
-		{"1D tensor", "layer_norm.weight", []int32{4096}, false},
+		{"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
 
 		// 3D+ tensors should not be quantized
-		{"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
-		{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
+		{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
+		{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
 
 		// Embeddings should not be quantized regardless of shape
-		{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
+		{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
 
 		// Norms should not be quantized regardless of shape
-		{"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
+		{"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
 
 		// Biases should not be quantized
-		{"bias 2D", "proj.bias", []int32{4096, 1}, false},
+		{"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
+
+		// Group size divisibility tests
+		// FP8/FP4 require divisible by 32
+		{"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
+		{"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
+		// NVFP4 requires divisible by 16
+		{"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
+		{"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			got := ShouldQuantizeTensor(tt.tensor, tt.shape)
+			got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
 			if got != tt.want {
-				t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
+				t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
 			}
 		})
 	}
@@ -15,15 +15,15 @@ import (
 // CreateImageGenModel imports an image generation model from a directory.
 // Stores each tensor as a separate blob for fine-grained deduplication.
 // If quantize is specified, linear weights in transformer/text_encoder are quantized.
-// Supported quantization types: fp8 (or empty for no quantization).
+// Supported quantization types: fp4, fp8 (or empty for no quantization).
 // Layer creation and manifest writing are done via callbacks to avoid import cycles.
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "fp4", "fp8":
+	case "", "fp4", "fp8", "nvfp4":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8, nvfp4", quantize)
 	}
 
 	var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
 
 		// Determine quantization type for this tensor (empty string if not quantizing)
 		quantizeType := ""
-		if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
+		if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
 			quantizeType = quantize
 		}
 
@@ -213,10 +213,15 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
 }
 
 // canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
-// MLX requires the last dimension to be divisible by the group size (32).
-func canQuantizeShape(shape []int32) bool {
+// MLX requires the last dimension to be divisible by the group size.
+// NVFP4 uses group_size=16, all other modes use group_size=32.
+func canQuantizeShape(shape []int32, quantize string) bool {
 	if len(shape) < 2 {
 		return false
 	}
-	return shape[len(shape)-1]%32 == 0
+	groupSize := int32(32)
+	if strings.ToUpper(quantize) == "NVFP4" {
+		groupSize = 16
+	}
+	return shape[len(shape)-1]%groupSize == 0
 }
x/imagegen/cache/cache.go (16 changes, vendored)
@@ -9,6 +9,7 @@ type Cache interface {
 	Offset() int
 	Len() int
 	State() []*mlx.Array
+	Reset()
 }
 
 type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
 func (c *KVCache) Offset() int { return c.offset }
 func (c *KVCache) Len() int { return c.offset }
 
+// Reset clears the cache state for a new generation session
+func (c *KVCache) Reset() {
+	c.keys = nil
+	c.values = nil
+	c.offset = 0
+}
+
 // RotatingKVCache implements sliding window attention with bounded memory
 type RotatingKVCache struct {
 	keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
 
 func (c *RotatingKVCache) Offset() int { return c.offset }
 func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
+
+// Reset clears the cache state for a new generation session
+func (c *RotatingKVCache) Reset() {
+	c.keys = nil
+	c.values = nil
+	c.offset = 0
+	c.idx = 0
+}
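As a usage note, the new Reset method lets a caller reuse the same per-layer cache slice across generation sessions instead of rebuilding it. A minimal sketch (hypothetical helper, not part of the diff, assuming only the Cache interface shown above):

```go
// resetAll clears every layer cache so the slice can be reused for a new,
// unrelated prompt without reallocating the caches themselves.
func resetAll(caches []cache.Cache) {
	for _, c := range caches {
		c.Reset()
	}
}
```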
@@ -19,6 +19,7 @@ import (
 	"github.com/ollama/ollama/x/imagegen/mlx"
 	"github.com/ollama/ollama/x/imagegen/models/flux2"
 	"github.com/ollama/ollama/x/imagegen/models/gemma3"
+	"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
 	"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
 	"github.com/ollama/ollama/x/imagegen/models/llama"
 	"github.com/ollama/ollama/x/imagegen/models/zimage"
@@ -242,6 +243,8 @@ func load(modelPath string) (Model, error) {
 		return gemma3.Load(modelPath)
 	case "gemma3_text":
 		return gemma3.LoadText(modelPath)
+	case "glm4_moe_lite":
+		return glm4_moe_lite.Load(modelPath)
 	default:
 		return llama.Load(modelPath)
 	}
@@ -116,6 +116,18 @@ func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
 	return layers
 }
 
+// GetAllTensorLayers returns all tensor layers without component filtering.
+// Used for LLM models where tensors don't have a component prefix.
+func (m *ModelManifest) GetAllTensorLayers() []ManifestLayer {
+	var layers []ManifestLayer
+	for _, layer := range m.Manifest.Layers {
+		if layer.MediaType == "application/vnd.ollama.image.tensor" {
+			layers = append(layers, layer)
+		}
+	}
+	return layers
+}
+
 // GetConfigLayer returns the config layer for a given path.
 func (m *ModelManifest) GetConfigLayer(configPath string) *ManifestLayer {
 	for _, layer := range m.Manifest.Layers {
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
 	return Concatenate([]*Array{a, b}, axis)
 }
 
+// Stack stacks arrays along a new axis (axis 0 by default)
+func Stack(arrays []*Array, axis int) *Array {
+	handles := make([]C.mlx_array, len(arrays))
+	for i, arr := range arrays {
+		handles[i] = arr.c
+	}
+	vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
+	res := C.mlx_array_new()
+	C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
+	C.mlx_vector_array_free(vec)
+	return newArray(res)
+}
+
 // Slice slices the array
 func Slice(a *Array, start, stop []int32) *Array {
 	n := len(start)
x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go (new file, 709 lines)
@@ -0,0 +1,709 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/nn"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
NSharedExperts int32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
||||
NGroup int32 `json:"n_group"`
|
||||
TopKGroup int32 `json:"topk_group"`
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim)
|
||||
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention
|
||||
type MLAAttention struct {
|
||||
// Low-rank query projections
|
||||
QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
|
||||
QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
|
||||
QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
|
||||
|
||||
// Low-rank KV projections (with shared rope component)
|
||||
KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
|
||||
KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
|
||||
KVBProj nn.LinearLayer `weight:"self_attn.kv_b_proj"`
|
||||
|
||||
// Output projection
|
||||
OProj nn.LinearLayer `weight:"self_attn.o_proj"`
|
||||
}
|
||||
|
||||
// Forward computes MLA attention output
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Query path: q_a_proj -> layernorm -> q_b_proj
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
||||
// Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
// Split Q into nope and rope parts
|
||||
qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
||||
|
||||
// KV path: kv_a_proj_with_mqa -> split -> layernorm -> kv_b_proj
|
||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
||||
|
||||
// Split into compressed_kv and k_pe (shared rope component)
|
||||
kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
||||
kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
||||
|
||||
// k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
|
||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
||||
|
||||
// Apply layernorm and project KV
|
||||
kvCompressed = a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
kv := a.KVBProj.Forward(kvCompressed)
|
||||
|
||||
// Reshape KV: [B, L, num_heads * (qk_nope_head_dim + v_head_dim)]
|
||||
kv = mlx.Reshape(kv, B, L, cfg.NumAttentionHeads, cfg.QKNopeHeadDim+cfg.VHeadDim)
|
||||
kv = mlx.Transpose(kv, 0, 2, 1, 3)
|
||||
|
||||
// Split into k_nope and values
|
||||
kNope := mlx.Slice(kv, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
values := mlx.Slice(kv, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim + cfg.VHeadDim})
|
||||
|
||||
// Apply RoPE to the rope parts only
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
// Repeat k_pe across all heads
|
||||
kPE = mlx.Tile(kPE, []int32{1, cfg.NumAttentionHeads, 1, 1})
|
||||
|
||||
// Concatenate nope and rope parts
|
||||
queries := mlx.Concatenate([]*mlx.Array{qNope, qPE}, 3)
|
||||
keys := mlx.Concatenate([]*mlx.Array{kNope, kPE}, 3)
|
||||
|
||||
// Update KV cache
|
||||
if c != nil {
|
||||
keys, values = c.Update(keys, values, int(L))
|
||||
}
|
||||
|
||||
// Scaled dot product attention
|
||||
out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
// Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
||||
type DenseMLP struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer `weight:"mlp.gate"`
|
||||
EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
// Compute gate logits through linear layer (handles both quantized and non-quantized)
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
// Sigmoid scoring
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
// Add correction bias if present
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
}
|
||||
|
||||
// Group-wise expert selection (simplified for n_group=1)
|
||||
// Select top-k experts
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
shape := inds.Shape()
|
||||
inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
|
||||
|
||||
// Get scores for selected experts
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
// Normalize if configured
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
}
|
||||
|
||||
// Apply routing scaling factor
|
||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
||||
|
||||
return inds, scores
|
||||
}
|
||||
|
||||
// SwitchMLP implements the MoE expert computation using stacked weights
|
||||
// Note: No weight tags - these are populated manually by stacking expert weights
|
||||
type SwitchMLP struct {
|
||||
GateWeight *mlx.Array
|
||||
UpWeight *mlx.Array
|
||||
DownWeight *mlx.Array
|
||||
}
|
||||
|
||||
// Forward applies the switched expert MLP
|
||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
topK := cfg.NumExpertsPerTok
|
||||
|
||||
// Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
|
||||
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
||||
|
||||
// Flatten for gather_mm: [B*L, 1, 1, D]
|
||||
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
||||
|
||||
// Flatten indices: [B, L, topK] -> [B*L, topK]
|
||||
idxFlat := mlx.Reshape(indices, B*L, topK)
|
||||
|
||||
// Sort for efficient gather (when we have many tokens)
|
||||
doSort := B*L >= 64
|
||||
var invOrder *mlx.Array
|
||||
n := B * L * topK
|
||||
|
||||
if doSort {
|
||||
idxAll := mlx.Flatten(idxFlat)
|
||||
order := mlx.Argsort(idxAll, 0)
|
||||
invOrder = mlx.Argsort(order, 0)
|
||||
// Reorder x based on sorted indices
|
||||
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
||||
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
||||
}
|
||||
|
||||
// Expert computation using gather_mm
|
||||
// gate: x @ gate_weight.T (indices are on the rhs/weight side)
|
||||
gate := mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
// up: x @ up_weight.T
|
||||
up := mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
// SwiGLU activation
|
||||
hidden := mlx.Mul(mlx.SiLU(gate), up)
|
||||
|
||||
// down: hidden @ down_weight.T
|
||||
down := mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||
|
||||
// Unsort if we sorted
|
||||
if doSort {
|
||||
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
||||
} else {
|
||||
down = mlx.Squeeze(down, 2)
|
||||
}
|
||||
|
||||
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// SharedExperts implements the shared expert MLP
|
||||
type SharedExperts struct {
|
||||
GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
|
||||
UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
|
||||
DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
|
||||
}
|
||||
|
||||
// Forward applies the shared expert MLP
|
||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
||||
up := s.UpProj.Forward(x)
|
||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoE implements the full Mixture of Experts layer
|
||||
type MoE struct {
|
||||
Gate *MoEGate
|
||||
SwitchMLP *SwitchMLP
|
||||
SharedExperts *SharedExperts
|
||||
}
|
||||
|
||||
// Forward applies the MoE layer
|
||||
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
shape := x.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
// Get expert indices and scores
|
||||
inds, scores := m.Gate.Forward(x, cfg)
|
||||
|
||||
// Apply routed experts
|
||||
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
||||
|
||||
// Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
|
||||
scoresExpanded := mlx.ExpandDims(scores, -1)
|
||||
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
||||
|
||||
// Add shared experts if present
|
||||
if m.SharedExperts != nil {
|
||||
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
||||
}
|
||||
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
|
||||
type DenseBlock struct {
|
||||
Attention *MLAAttention
|
||||
MLP *DenseMLP
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MLP with residual
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// MoEBlock represents a MoE transformer block
|
||||
type MoEBlock struct {
|
||||
Attention *MLAAttention
|
||||
MoE *MoE
|
||||
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
|
||||
PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
|
||||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
// Pre-norm attention with residual
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
// Pre-norm MoE with residual
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
return mlx.Add(h, r)
|
||||
}
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
type Model struct {
|
||||
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
|
||||
Layers []Block `weight:"-"` // Loaded manually due to different block types
|
||||
Norm *nn.RMSNorm `weight:"model.norm"`
|
||||
LMHead nn.LinearLayer `weight:"lm_head"`
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
*Config
|
||||
}
|
||||
|
||||
// loadExpertWeight loads an expert weight, dequantizing if necessary.
|
||||
// GatherMM doesn't support quantized weights, so we must dequantize for MoE.
|
||||
func loadExpertWeight(weights safetensors.WeightSource, path string) *mlx.Array {
|
||||
w, _ := weights.GetTensor(path + ".weight")
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if this is a quantized weight by looking for scales
|
||||
scalePath := path + ".weight_scale"
|
||||
if weights.HasTensor(scalePath) {
|
||||
scales, _ := weights.GetTensor(scalePath)
|
||||
var qbiases *mlx.Array
|
||||
qbiasPath := path + ".weight_qbias"
|
||||
if weights.HasTensor(qbiasPath) {
|
||||
qbiases, _ = weights.GetTensor(qbiasPath)
|
||||
}
|
||||
// Dequantize using the model's quantization parameters
|
||||
groupSize, bits, mode := safetensors.QuantizationParams(weights.Quantization())
|
||||
return mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// sanitizeExpertWeights stacks individual expert weights into a single tensor.
|
||||
// For quantized models, expert weights are dequantized since GatherMM doesn't support quantized weights.
|
||||
func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
|
||||
var gateWeights, upWeights, downWeights []*mlx.Array
|
||||
|
||||
for e := int32(0); e < numExperts; e++ {
|
||||
gw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.gate_proj", prefix, e))
|
||||
uw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.up_proj", prefix, e))
|
||||
dw := loadExpertWeight(weights, fmt.Sprintf("%s.mlp.experts.%d.down_proj", prefix, e))
|
||||
|
||||
if gw != nil {
|
||||
gateWeights = append(gateWeights, gw)
|
||||
}
|
||||
if uw != nil {
|
||||
upWeights = append(upWeights, uw)
|
||||
}
|
||||
if dw != nil {
|
||||
downWeights = append(downWeights, dw)
|
||||
}
|
||||
}
|
||||
|
||||
var stackedGate, stackedUp, stackedDown *mlx.Array
|
||||
if len(gateWeights) > 0 {
|
||||
stackedGate = mlx.Stack(gateWeights, 0)
|
||||
}
|
||||
if len(upWeights) > 0 {
|
||||
stackedUp = mlx.Stack(upWeights, 0)
|
||||
}
|
||||
if len(downWeights) > 0 {
|
||||
stackedDown = mlx.Stack(downWeights, 0)
|
||||
}
|
||||
|
||||
return stackedGate, stackedUp, stackedDown
|
||||
}
|
||||
|
||||
// Load loads a GLM4-MoE-Lite model from the given path
|
||||
func Load(modelPath string) (*Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute derived fields
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
|
||||
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load embedding, norm, and lm_head
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers manually due to different block types
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d attention: %w", i, err)
|
||||
}
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
block := &DenseBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d dense: %w", i, err)
|
||||
}
|
||||
m.Layers[i] = block
|
||||
} else {
|
||||
// MoE block
|
||||
block := &MoEBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
|
||||
}
|
||||
|
||||
// Stack expert weights
|
||||
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
|
||||
|
||||
block.MoE = &MoE{
|
||||
Gate: &MoEGate{},
|
||||
SwitchMLP: &SwitchMLP{
|
||||
GateWeight: gateW,
|
||||
UpWeight: upW,
|
||||
DownWeight: downW,
|
||||
},
|
||||
}
|
||||
|
||||
// Load gate weights
|
||||
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d gate: %w", i, err)
|
||||
}
|
||||
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{}
|
||||
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
|
||||
func LoadFromManifest(manifest *imagegen.ModelManifest) (*Model, error) {
|
||||
// Read config from manifest
|
||||
configData, err := manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Compute derived fields
|
||||
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.QHeadDim)))
|
||||
|
||||
// Load weights from manifest blobs
|
||||
weights, err := imagegen.LoadAllWeightsFromManifest(manifest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load weights: %w", err)
|
||||
}
|
||||
|
||||
// Debug: print quantization info and sample tensor names
|
||||
fmt.Printf("GLM4: quantization=%q, num_tensors=%d\n", weights.Quantization(), len(weights.ListTensors()))
|
||||
tensors := weights.ListTensors()
|
||||
for i, name := range tensors {
|
||||
if i < 20 { // Print first 20 tensor names
|
||||
fmt.Printf(" tensor[%d]: %s\n", i, name)
|
||||
}
|
||||
}
|
||||
|
||||
if err := weights.Load(0); err != nil {
|
||||
return nil, fmt.Errorf("load weight data: %w", err)
|
||||
}
|
||||
|
||||
// Load tokenizer from manifest with config files for EOS token detection
|
||||
tokData, err := manifest.ReadConfig("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||
}
|
||||
|
||||
// Build tokenizer config with companion files for EOS/BOS token loading
|
||||
tokConfig := &tokenizer.TokenizerConfig{
|
||||
ConfigJSON: configData, // Already loaded above, contains eos_token_id
|
||||
}
|
||||
|
||||
// Try to load generation_config.json if available (preferred source for EOS)
|
||||
if genConfigData, err := manifest.ReadConfig("generation_config.json"); err == nil {
|
||||
tokConfig.GenerationConfigJSON = genConfigData
|
||||
}
|
||||
|
||||
// Try to load tokenizer_config.json if available
|
||||
if tokConfigData, err := manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||
}
|
||||
|
||||
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
Layers: make([]Block, cfg.NumHiddenLayers),
|
||||
Config: &cfg,
|
||||
tok: tok,
|
||||
}
|
||||
|
||||
// Load embedding, norm, and lm_head
|
||||
if err := safetensors.LoadModule(m, weights, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load layers manually due to different block types
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
|
||||
// Load attention (same for both block types)
|
||||
attn := &MLAAttention{}
|
||||
if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d attention: %w", i, err)
|
||||
}
|
||||
|
||||
if i < cfg.FirstKDenseReplace {
|
||||
// Dense block
|
||||
block := &DenseBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d dense: %w", i, err)
|
||||
}
|
||||
m.Layers[i] = block
|
||||
} else {
|
||||
// MoE block
|
||||
block := &MoEBlock{Attention: attn}
|
||||
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d moe block: %w", i, err)
|
||||
}
|
||||
|
||||
// Stack expert weights
|
||||
gateW, upW, downW := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts)
|
||||
|
||||
block.MoE = &MoE{
|
||||
Gate: &MoEGate{},
|
||||
SwitchMLP: &SwitchMLP{
|
||||
GateWeight: gateW,
|
||||
UpWeight: upW,
|
||||
DownWeight: downW,
|
||||
},
|
||||
}
|
||||
|
||||
// Load gate weights
|
||||
if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d gate: %w", i, err)
|
||||
}
|
||||
|
||||
// Load shared experts if present
|
||||
if cfg.NSharedExperts > 0 {
|
||||
block.MoE.SharedExperts = &SharedExperts{}
|
||||
if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
|
||||
return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.Layers[i] = block
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Eval(mlx.Collect(m)...)
|
||||
weights.ReleaseAll()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Forward computes the forward pass of the model
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
B, L := tokens.Shape()[0], tokens.Shape()[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
return m.LMHead.Forward(h)
|
||||
}
|
||||
|
||||
// Interface methods
|
||||
|
||||
// NumLayers returns the number of transformer layers
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
// Tokenizer returns the model's tokenizer
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
||||
|
||||
// NewCache creates a new KV cache for the model
|
||||
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
||||
// This follows the GLM-4.7 format with <think> tag for reasoning mode.
|
||||
func (m *Model) FormatPrompt(prompt string) string {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
|
||||
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
||||
// When think is true, the prompt ends with <think> to enable reasoning mode.
|
||||
// When think is false, the prompt ends with </think> to skip reasoning.
|
||||
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
||||
if think {
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
||||
}
|
||||
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
||||
}
|
||||
|
||||
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
|
||||
func (m *Model) NewRenderer() *Renderer {
|
||||
return &Renderer{}
|
||||
}
|
||||
|
||||
// NewParser returns a new Parser for extracting thinking and tool calls from output.
|
||||
func (m *Model) NewParser() *Parser {
|
||||
return &Parser{}
|
||||
}
|
||||
x/imagegen/models/glm4_moe_lite/parser.go (new file, 479 lines)
@@ -0,0 +1,479 @@
|
||||
//go:build mlx
|
||||
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type parserState int
|
||||
|
||||
const (
|
||||
parserState_LookingForThinkingOpen parserState = iota
|
||||
parserState_ThinkingStartedEatingWhitespace
|
||||
parserState_CollectingThinking
|
||||
parserState_ThinkingDoneEatingWhitespace
|
||||
parserState_CollectingContent
|
||||
parserState_ToolStartedEatingWhitespace
|
||||
parserState_CollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
thinkingOpenTag = "<think>"
|
||||
thinkingCloseTag = "</think>"
|
||||
toolOpenTag = "<tool_call>"
|
||||
toolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
|
||||
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
|
||||
// must start in CollectingThinking state (the model outputs thinking content directly).
|
||||
type Parser struct {
|
||||
state parserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
// HasToolSupport returns true as GLM4 supports tool calling.
|
||||
func (p *Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HasThinkingSupport returns true as GLM4 supports thinking mode.
|
||||
func (p *Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init initializes the parser with tools and thinking configuration.
|
||||
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
p.state = parserState_CollectingThinking
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
type parserEvent interface {
	isParserEvent()
}

type eventContent struct {
	content string
}

func (eventContent) isParserEvent() {}

type eventRawToolCall struct {
	raw string
}

func (eventRawToolCall) isParserEvent() {}

type eventThinkingContent struct {
	content string
}

func (eventThinkingContent) isParserEvent() {}

// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder

	for _, event := range events {
		switch event := event.(type) {
		case eventRawToolCall:
			toolCall, err := parseToolCall(event, p.tools)
			if err != nil {
				slog.Warn("glm-4 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			toolCalls = append(toolCalls, toolCall)
		case eventThinkingContent:
			thinkingSb.WriteString(event.content)
		case eventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *Parser) parseEvents() []parserEvent {
	var all []parserEvent

	keepLooping := true
	for keepLooping {
		var events []parserEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	if len(all) > 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
	}

	return all
}

// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
	p.buffer.Reset()
	if trimmed == "" {
		return nil, false // Still only whitespace, keep waiting for more input
	}
	p.state = nextState
	p.buffer.WriteString(trimmed)
	return nil, true // Successfully transitioned
}

// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
	split := strings.SplitN(p.buffer.String(), tag, 2)
	before := split[0]
	before = strings.TrimRightFunc(before, unicode.IsSpace)
	after := split[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}
	p.buffer.Reset()
	p.buffer.WriteString(after)
	return before, after
}

func (p *Parser) eat() ([]parserEvent, bool) {
	var events []parserEvent

	switch p.state {
	case parserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, thinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, thinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = parserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = parserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case parserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)

	case parserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, thinkingCloseTag) {
			thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, eventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = parserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = parserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)

	case parserState_CollectingContent:
		if strings.Contains(p.buffer.String(), toolOpenTag) {
			before, after := p.splitAtTag(toolOpenTag, true)
			if len(before) > 0 {
				events = append(events, eventContent{content: before})
			}
			if after == "" {
				p.state = parserState_ToolStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		} else {
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)

	case parserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, toolCloseTag) {
			toolContent, _ := p.splitAtTag(toolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm4 tool call closing tag found but no content before it")
			}
			events = append(events, eventRawToolCall{raw: toolContent})
			p.state = parserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed
			// We just wait for the closing tag
			return events, false
		}

	default:
		panic("unreachable")
	}
}

// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
	for i := 1; i <= len(tag) && i <= len(s); i++ {
		if strings.HasSuffix(s, tag[:i]) {
			return i
		}
	}
	return 0
}

// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return len(s) - len(trimmed)
}

// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}

// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
	var result strings.Builder
	inTag := false

	for i := range len(s) {
		ch := s[i]

		if ch == '<' {
			// Check if this is a known tag
			if strings.HasPrefix(s[i:], "<arg_key>") ||
				strings.HasPrefix(s[i:], "</arg_key>") ||
				strings.HasPrefix(s[i:], "<arg_value>") ||
				strings.HasPrefix(s[i:], "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
		} else {
			// Escape special characters in text content
			switch ch {
			case '&':
				result.WriteString("&amp;")
			case '<':
				result.WriteString("&lt;")
			case '>':
				result.WriteString("&gt;")
			default:
				result.WriteByte(ch)
			}
		}
	}

	return result.String()
}

func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	escaped := escapeContent(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed ToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}

// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
	value = strings.TrimSpace(value)

	// If no type specified, return as string
	if len(paramType) == 0 {
		return value
	}

	// Try to parse based on specified types
	for _, t := range paramType {
		switch t {
		case "boolean":
			if value == "true" {
				return true
			}
			if value == "false" {
				return false
			}
		case "integer":
			var i int64
			if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
				return i
			}
		case "number":
			var f float64
			if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
				return f
			}
		case "array", "object":
			// Try to parse as JSON
			var result any
			if err := json.Unmarshal([]byte(value), &result); err == nil {
				return result
			}
		}
	}

	// Default to string
	return value
}
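
Before the tests below, a minimal usage sketch of how a caller might drive this parser over streamed model output. It assumes only the Parser, api.ToolCall, and api.ThinkValue types shown above; the chunk boundaries are invented purely for illustration.

	// Hypothetical streaming loop: feed chunks as they arrive and collect results.
	p := &Parser{}
	p.Init(nil, nil, &api.ThinkValue{Value: true}) // thinking enabled: output starts inside <think>

	var thinking, content strings.Builder
	var calls []api.ToolCall
	for _, chunk := range []string{"Consider the ques", "tion...</think>The answer", " is 4."} {
		c, th, tc, err := p.Add(chunk, false)
		if err != nil {
			break // a malformed <tool_call> block surfaces here
		}
		content.WriteString(c)
		thinking.WriteString(th)
		calls = append(calls, tc...)
	}
	// thinking == "Consider the question...", content == "The answer is 4.", len(calls) == 0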
x/imagegen/models/glm4_moe_lite/parser_test.go (new file, 192 lines)
@@ -0,0 +1,192 @@
//go:build mlx

package glm4_moe_lite

import (
	"testing"

	"github.com/ollama/ollama/api"
)

func TestParserThinking(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		thinkEnabled  bool
		wantContent   string
		wantThinking  string
		wantToolCalls int
	}{
		{
			name:         "thinking enabled - simple thinking then content",
			input:        "Let me think about this...</think>Here is my answer.",
			thinkEnabled: true,
			wantThinking: "Let me think about this...",
			wantContent:  "Here is my answer.",
		},
		{
			name:         "thinking enabled - only thinking",
			input:        "I need to consider multiple factors...",
			thinkEnabled: true,
			wantThinking: "I need to consider multiple factors...",
			wantContent:  "",
		},
		{
			name:         "thinking disabled - direct content",
			input:        "Here is my direct answer.",
			thinkEnabled: false,
			wantThinking: "",
			wantContent:  "Here is my direct answer.",
		},
		{
			name:          "thinking with tool call",
			input:         "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
			thinkEnabled:  true,
			wantThinking:  "Let me search for that...",
			wantContent:   "I'll use a tool.",
			wantToolCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Parser{}

			var thinkValue *api.ThinkValue
			if tt.thinkEnabled {
				thinkValue = &api.ThinkValue{Value: true}
			} else {
				thinkValue = &api.ThinkValue{Value: false}
			}

			// Define tools for tool call tests
			props := api.NewToolPropertiesMap()
			props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
			tools := []api.Tool{
				{
					Function: api.ToolFunction{
						Name: "search",
						Parameters: api.ToolFunctionParameters{
							Properties: props,
						},
					},
				},
			}

			p.Init(tools, nil, thinkValue)

			content, thinking, calls, err := p.Add(tt.input, true)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if thinking != tt.wantThinking {
				t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
			}
			if content != tt.wantContent {
				t.Errorf("content = %q, want %q", content, tt.wantContent)
			}
			if len(calls) != tt.wantToolCalls {
				t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
			}
		})
	}
}

func TestParserToolCall(t *testing.T) {
	p := &Parser{}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: props,
				},
			},
		},
	}

	// Initialize with thinking disabled
	tv := &api.ThinkValue{Value: false}
	p.Init(tools, nil, tv)

	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"

	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	call := calls[0]
	if call.Function.Name != "get_weather" {
		t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
	}

	location, ok := call.Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Errorf("location = %v, want %q", location, "San Francisco")
	}

	unit, ok := call.Function.Arguments.Get("unit")
	if !ok || unit != "celsius" {
		t.Errorf("unit = %v, want %q", unit, "celsius")
	}
}

func TestOverlap(t *testing.T) {
	tests := []struct {
		s    string
		tag  string
		want int
	}{
		{"hello<", "</think>", 1},
		{"hello</", "</think>", 2},
		{"hello</t", "</think>", 3},
		{"hello</th", "</think>", 4},
		{"hello</thi", "</think>", 5},
		{"hello</thin", "</think>", 6},
		{"hello</think", "</think>", 7},
		{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
		{"hello", "</think>", 0},
		{"", "</think>", 0},
	}

	for _, tt := range tests {
		t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
			got := overlap(tt.s, tt.tag)
			if got != tt.want {
				t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
			}
		})
	}
}

func TestTrailingWhitespaceLen(t *testing.T) {
	tests := []struct {
		s    string
		want int
	}{
		{"hello   ", 3},
		{"hello\n\t ", 3},
		{"hello", 0},
		{"", 0},
		{"   ", 3},
	}

	for _, tt := range tests {
		t.Run(tt.s, func(t *testing.T) {
			got := trailingWhitespaceLen(tt.s)
			if got != tt.want {
				t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
			}
		})
	}
}
x/imagegen/models/glm4_moe_lite/render.go (new file, 175 lines)
@@ -0,0 +1,175 @@
//go:build mlx

package glm4_moe_lite

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
//   included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}

// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(formatToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}

// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
	var sb strings.Builder
	for key, value := range args.All() {
		sb.WriteString("<arg_key>" + key + "</arg_key>")
		var valueStr string
		if str, ok := value.(string); ok {
			valueStr = str
		} else {
			jsonBytes, err := json.Marshal(value)
			if err != nil {
				valueStr = fmt.Sprintf("%v", value)
			} else {
				valueStr = string(jsonBytes)
			}
		}

		sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
	}

	return sb.String()
}

// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
	var sb strings.Builder
	sb.Grow(len(raw) + len(raw)/10)

	inString := false
	escaped := false
	for i := range raw {
		ch := raw[i]
		sb.WriteByte(ch)

		if inString {
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == '"' {
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			continue
		}

		if ch == ':' || ch == ',' {
			sb.WriteByte(' ')
		}
	}

	return sb.String()
}
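
For orientation before the tests, a hedged sketch of what Render produces for a tool-call exchange with thinking left at its default. It reuses only the Renderer and api types above; the argument values are the same illustrative ones used in the tests that follow.

	// Minimal sketch: rendering a user question, an assistant tool call, and a tool result.
	args := api.NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")

	r := &Renderer{}
	prompt, _ := r.Render([]api.Message{
		{Role: "user", Content: "What's the weather in SF?"},
		{Role: "assistant", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: args}}}},
		{Role: "tool", Content: "Sunny, 72F"},
	}, nil, nil)
	// prompt begins with "[gMASK]<sop>" and ends with
	// "<|observation|><tool_response>Sunny, 72F</tool_response><|assistant|><think>",
	// inviting the model to reason about the tool output before answering.
	_ = prompt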
x/imagegen/models/glm4_moe_lite/render_test.go (new file, 205 lines)
@@ -0,0 +1,205 @@
//go:build mlx

package glm4_moe_lite

import (
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestRendererSimple(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	// Thinking enabled (default)
	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}

func TestRendererThinkingDisabled(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	tv := &api.ThinkValue{Value: false}

	result, err := r.Render(messages, nil, tv)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}

func TestRendererMultiTurn(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What is 2+2?"},
		{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
		{Role: "user", Content: "And 3+3?"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check key parts
	if !strings.Contains(result, "[gMASK]<sop>") {
		t.Error("missing [gMASK]<sop> prefix")
	}
	if !strings.Contains(result, "<|user|>What is 2+2?") {
		t.Error("missing first user message")
	}
	if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
		t.Error("missing assistant message with thinking")
	}
	if !strings.Contains(result, "<|user|>And 3+3?") {
		t.Error("missing second user message")
	}
	if !strings.HasSuffix(result, "<|assistant|><think>") {
		t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
	}
}

func TestRendererWithSystem(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
		t.Error("missing system message")
	}
}

func TestRendererWithTools(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What's the weather?"},
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get the weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
					Required:   []string{"location"},
				},
			},
		},
	}

	result, err := r.Render(messages, tools, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check for tool system prompt
	if !strings.Contains(result, "<|system|>") {
		t.Error("missing system tag for tools")
	}
	if !strings.Contains(result, "# Tools") {
		t.Error("missing tools header")
	}
	if !strings.Contains(result, "<tools>") {
		t.Error("missing tools tag")
	}
	if !strings.Contains(result, "get_weather") {
		t.Error("missing tool name")
	}
	if !strings.Contains(result, "</tools>") {
		t.Error("missing closing tools tag")
	}
}

func TestRendererWithToolCalls(t *testing.T) {
	r := &Renderer{}

	args := api.NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")

	messages := []api.Message{
		{Role: "user", Content: "What's the weather in SF?"},
		{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: args,
					},
				},
			},
		},
		{Role: "tool", Content: "Sunny, 72F"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<tool_call>get_weather") {
		t.Error("missing tool call")
	}
	if !strings.Contains(result, "<arg_key>location</arg_key>") {
		t.Error("missing arg_key")
	}
	if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
		t.Error("missing arg_value")
	}
	if !strings.Contains(result, "</tool_call>") {
		t.Error("missing tool call closing tag")
	}
	if !strings.Contains(result, "<|observation|>") {
		t.Error("missing observation tag")
	}
	if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
		t.Error("missing tool response")
	}
}

func TestFormatToolJSON(t *testing.T) {
	input := []byte(`{"name":"test","value":123}`)
	result := formatToolJSON(input)

	// Should add spaces after : and ,
	if !strings.Contains(result, ": ") {
		t.Error("should add space after colon")
	}
	if !strings.Contains(result, ", ") {
		t.Error("should add space after comma")
	}
}
@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {

// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
	qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
	// Eval immediately so bf16 weight can be freed
	mlx.Eval(qw, scales, qbiases)
	// Handle modes that don't return qbiases (e.g., nvfp4)
	if qbiases != nil {
		mlx.Eval(qw, scales, qbiases)
	} else {
		mlx.Eval(qw, scales)
	}
	return &QuantizedLinear{
		Weight: qw,
		Scales: scales,

@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear

// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
// Supports multiple quantization modes:
// - "affine": scale + zero-point bias (QBiases required)
// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
	Weight  *mlx.Array // Quantized weight data
	Scales  *mlx.Array // Scale factors for dequantization
	QBiases *mlx.Array // Quantization biases (NOT layer bias)
	QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
	Bias    *mlx.Array // Layer bias [output_dims] or nil
	GroupSize int
	Bits      int

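A hedged sketch of how the two modes differ for a caller of the constructor above. It assumes weight is an already-loaded bf16 *mlx.Array; the group sizes follow the QuantizationParams mapping shown later in this diff.

	// Affine 4-bit: mlx.Quantize returns scales and zero-point qbiases, both kept.
	affine := NewQuantizedLinear(weight, nil, 32, 4, "affine") // affine.QBiases != nil

	// NVFP4: mlx.Quantize returns no qbiases, so only Weight and Scales are evaluated.
	nv := NewQuantizedLinear(weight, nil, 16, 4, "nvfp4") // nv.QBiases == nil
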
@@ -1,284 +0,0 @@
//go:build mlx

// Package runner provides a subprocess server for image generation.
// It listens on a port and handles HTTP requests for image generation.
package runner

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"image"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/flux2"
	"github.com/ollama/ollama/x/imagegen/models/zimage"
)

// Request is the image generation request format
type Request struct {
	Prompt string   `json:"prompt"`
	Width  int32    `json:"width,omitempty"`
	Height int32    `json:"height,omitempty"`
	Steps  int      `json:"steps,omitempty"`
	Seed   int64    `json:"seed,omitempty"`
	Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}

// Response is streamed back for each progress update
type Response struct {
	Content string `json:"content,omitempty"`
	Image   string `json:"image,omitempty"` // Base64-encoded PNG
	Done    bool   `json:"done"`
	Step    int    `json:"step,omitempty"`
	Total   int    `json:"total,omitempty"`
}

// ImageModel is the interface for image generation models
type ImageModel interface {
	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}

// ImageEditModel extends ImageModel with image editing/conditioning capability.
// Models that support input images for editing should implement this interface.
type ImageEditModel interface {
	ImageModel
	GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error)
}

// Server holds the model and handles requests
type Server struct {
	mu        sync.Mutex
	model     ImageModel
	modelName string
}

// Execute is the entry point for the image runner subprocess
func Execute(args []string) error {
	fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
	modelName := fs.String("model", "", "path to image model")
	port := fs.Int("port", 0, "port to listen on")

	if err := fs.Parse(args); err != nil {
		return err
	}

	if *modelName == "" {
		return fmt.Errorf("--model is required")
	}
	if *port == 0 {
		return fmt.Errorf("--port is required")
	}

	err := mlx.InitMLX()
	if err != nil {
		slog.Error("unable to initialize MLX", "error", err)
		return err
	}
	slog.Info("MLX library initialized")
	slog.Info("starting image runner", "model", *modelName, "port", *port)

	// Detect model type and load appropriate model
	modelType := imagegen.DetectModelType(*modelName)
	slog.Info("detected model type", "type", modelType)

	var model ImageModel
	switch modelType {
	case "Flux2KleinPipeline":
		m := &flux2.Model{}
		if err := m.Load(*modelName); err != nil {
			return fmt.Errorf("failed to load model: %w", err)
		}
		model = m
	default:
		// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
		m := &zimage.Model{}
		if err := m.Load(*modelName); err != nil {
			return fmt.Errorf("failed to load model: %w", err)
		}
		model = m
	}

	server := &Server{
		model:     model,
		modelName: *modelName,
	}

	// Set up HTTP handlers
	mux := http.NewServeMux()
	mux.HandleFunc("/health", server.healthHandler)
	mux.HandleFunc("/completion", server.completionHandler)

	httpServer := &http.Server{
		Addr:    fmt.Sprintf("127.0.0.1:%d", *port),
		Handler: mux,
	}

	// Handle shutdown
	done := make(chan struct{})
	go func() {
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
		<-sigCh
		slog.Info("shutting down image runner")
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		httpServer.Shutdown(ctx)
		close(done)
	}()

	slog.Info("image runner listening", "addr", httpServer.Addr)
	if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}

	<-done
	return nil
}

func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}

func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req Request
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Validate and decode input images
	const maxInputImages = 2
	if len(req.Images) > maxInputImages {
		http.Error(w, fmt.Sprintf("too many input images, maximum is %d", maxInputImages), http.StatusBadRequest)
		return
	}

	var inputImages []image.Image
	if len(req.Images) > 0 {
		// TODO: add memory check for input images

		inputImages = make([]image.Image, len(req.Images))
		for i, imgBytes := range req.Images {
			img, err := imagegen.DecodeImage(imgBytes)
			if err != nil {
				http.Error(w, fmt.Sprintf("invalid image %d: %v", i, err), http.StatusBadRequest)
				return
			}
			inputImages[i] = img
		}
		slog.Info("decoded input images", "count", len(inputImages))

		// Default width/height to first input image dimensions, scaled to max 1024
		bounds := inputImages[0].Bounds()
		w, h := bounds.Dx(), bounds.Dy()
		if w > 1024 || h > 1024 {
			if w > h {
				h = h * 1024 / w
				w = 1024
			} else {
				w = w * 1024 / h
				h = 1024
			}
		}
		req.Width = int32(w)
		req.Height = int32(h)
	}

	// Serialize generation requests - MLX model may not handle concurrent generation
	s.mu.Lock()
	defer s.mu.Unlock()

	// Model applies its own defaults for width/height/steps
	// Only seed needs to be set here if not provided
	if req.Seed <= 0 {
		req.Seed = time.Now().UnixNano()
	}

	// Set up streaming response
	w.Header().Set("Content-Type", "application/x-ndjson")
	w.Header().Set("Transfer-Encoding", "chunked")
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming not supported", http.StatusInternalServerError)
		return
	}

	// Generate image using the common interface
	ctx := r.Context()
	enc := json.NewEncoder(w)

	// Progress callback streams step updates
	progress := func(step, total int) {
		resp := Response{Step: step, Total: total}
		enc.Encode(resp)
		w.Write([]byte("\n"))
		flusher.Flush()
	}

	// Use ImageEditModel if available and images provided, otherwise use basic ImageModel
	var img *mlx.Array
	var err error
	if len(inputImages) > 0 {
		editModel, ok := s.model.(ImageEditModel)
		if !ok {
			http.Error(w, "model does not support image editing", http.StatusBadRequest)
			return
		}
		img, err = editModel.GenerateImageWithInputs(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, inputImages, progress)
	} else {
		img, err = s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
	}

	if err != nil {
		// Don't send error for cancellation
		if ctx.Err() != nil {
			return
		}
		resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
		data, _ := json.Marshal(resp)
		w.Write(data)
		w.Write([]byte("\n"))
		return
	}

	// Encode image as base64 PNG
	imageData, err := imagegen.EncodeImageBase64(img)
	if err != nil {
		resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
		data, _ := json.Marshal(resp)
		w.Write(data)
		w.Write([]byte("\n"))
		return
	}

	// Free the generated image array and clean up MLX state
	img.Free()
	mlx.ClearCache()
	mlx.MetalResetPeakMemory()

	// Send final response with image data
	resp := Response{
		Image: imageData,
		Done:  true,
	}
	data, _ := json.Marshal(resp)
	w.Write(data)
	w.Write([]byte("\n"))
	flusher.Flush()
}
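
For context on the subprocess protocol this removal retires, a hedged sketch of one /completion exchange using the Request and Response structs above; the field values are invented for illustration.

	Request body (single JSON object):
	  {"prompt":"a lighthouse at dusk","width":1024,"height":1024,"steps":20,"seed":42}
	Streamed NDJSON responses: one progress line per step, then a final line carrying the image:
	  {"done":false,"step":1,"total":20}
	  ...
	  {"image":"<base64 PNG>","done":true}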
@@ -17,17 +17,26 @@ type WeightSource interface {
	GetTensor(name string) (*mlx.Array, error)
	ListTensors() []string
	HasTensor(name string) bool
	Quantization() string // Returns "FP4", "FP8", or ""
	Quantization() string // Returns "NVFP4", "FP4", "FP8", or ""
}

// quantizationParams returns groupSize, bits, mode for a quantization type.
// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
// QuantizationParams returns groupSize, bits, mode for a quantization type.
// MLX quantization modes:
// - "affine": scale + zero-point bias, group_size=32/64/128
// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
	switch strings.ToUpper(quantization) {
	case "FP4":
	case "NVFP4":
		// NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
		return 16, 4, "nvfp4"
	case "FP4", "Q4", "INT4":
		// 4-bit quantization with affine mode (scale + qbias)
		return 32, 4, "affine"
	case "FP8", "Q8", "INT8", "":
		// 8-bit quantization with affine mode (default for quantized models)
		return 32, 8, "affine"
	default:
		return 32, 8, "affine" // FP8 or unknown
		return 32, 8, "affine" // Default to affine
	}
}

@@ -122,7 +131,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
	}

	// Handle nn.LinearLayer interface fields specially
	if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
	linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
	if field.Type == linearLayerType {
		if !hasTag {
			continue // no tag = skip
		}
@@ -217,11 +227,12 @@ func joinPath(prefix, suffix string) string {
}

// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
// If {path}.weight_scale exists, dequantizes the weights.
// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
	// Check if this is a quantized layer by looking for scale tensor
	scalePath := path + ".weight_scale"
	if weights.HasTensor(scalePath) {
	hasScale := weights.HasTensor(scalePath)
	if hasScale {
		weight, err := weights.GetTensor(path + ".weight")
		if err != nil {
			return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +256,11 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
			qbiases, _ = weights.GetTensor(qbiasPath)
		}

		groupSize, bits, mode := quantizationParams(weights.Quantization())
		groupSize, bits, mode := QuantizationParams(weights.Quantization())

		if mlx.MetalIsAvailable() {
		// NVFP4 doesn't have native quantized matmul kernels in MLX yet,
		// so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
		if mlx.MetalIsAvailable() && mode != "nvfp4" {
			return &nn.QuantizedLinear{
				Weight: weight,
				Scales: scales,

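To summarize the mapping added above, a short sketch of the values a caller would get back from QuantizationParams; this simply restates the switch cases and assumes nothing beyond them.

	gs, bits, mode := QuantizationParams("NVFP4") // 16, 4, "nvfp4"
	gs, bits, mode = QuantizationParams("fp4")    // 32, 4, "affine" (lookup is case-insensitive)
	gs, bits, mode = QuantizationParams("")       // 32, 8, "affine" (default for unknown/unspecified)
	_, _, _ = gs, bits, mode
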
@@ -1,48 +0,0 @@
package imagegen

import (
	"runtime"
	"testing"
)

// TestPlatformSupport verifies platform validation works correctly.
func TestPlatformSupport(t *testing.T) {
	err := CheckPlatformSupport()

	switch runtime.GOOS {
	case "darwin":
		if runtime.GOARCH == "arm64" {
			// Apple Silicon should be supported
			if err != nil {
				t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
			}
		} else {
			// Intel Mac should fail
			if err == nil {
				t.Error("Expected error on darwin/amd64 (Intel), got nil")
			}
			if err != nil && err.Error() == "" {
				t.Error("Expected meaningful error message for unsupported platform")
			}
		}
	case "linux", "windows":
		// Linux/Windows are allowed (CUDA support checked at runtime)
		if err != nil {
			t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
		}
	default:
		// Other platforms should fail
		if err == nil {
			t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
		}
	}
}

// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
// This is a compile-time check but we document it as a test.
func TestServerInterfaceCompliance(t *testing.T) {
	// The var _ llm.LlamaServer = (*Server)(nil) line in server.go
	// ensures compile-time interface compliance.
	// This test documents that requirement.
	t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
}
@@ -44,23 +44,54 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
	}, nil
}

// LoadAllWeightsFromManifest creates a weight loader for all tensors without component filtering.
// Used for LLM models where tensors don't have a component prefix.
func LoadAllWeightsFromManifest(manifest *ModelManifest) (*ManifestWeights, error) {
	layers := manifest.GetAllTensorLayers()
	if len(layers) == 0 {
		return nil, fmt.Errorf("no tensor layers found in manifest")
	}

	tensors := make(map[string]ManifestLayer, len(layers))
	for _, layer := range layers {
		tensors[layer.Name] = layer
	}

	return &ManifestWeights{
		manifest: manifest,
		tensors:  tensors,
		cache:    make(map[string]*mlx.Array),
	}, nil
}

// Load loads all tensor blobs using native mmap (zero-copy).
// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
// If dtype is non-zero, tensors are converted to the specified dtype.
func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
	// Track native handles to free after batch eval
	nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
	arrays := make([]*mlx.Array, 0, len(mw.tensors))

	for name, layer := range mw.tensors {
		path := mw.manifest.BlobPath(layer.Digest)

		// Load blob as safetensors (native mmap, zero-copy)
		sf, err := mlx.LoadSafetensorsNative(path)
		if err != nil {
			// Free any handles we've accumulated
			for _, h := range nativeHandles {
				h.Free()
			}
			return fmt.Errorf("load %s: %w", name, err)
		}
		nativeHandles = append(nativeHandles, sf)

		// Blob contains single tensor named "data"
		arr := sf.Get("data")
		if arr == nil {
			sf.Free()
			for _, h := range nativeHandles {
				h.Free()
			}
			return fmt.Errorf("tensor 'data' not found in blob for %s", name)
		}

@@ -68,11 +99,18 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
		if dtype != 0 && arr.Dtype() != dtype {
			arr = mlx.AsType(arr, dtype)
		}
		// ALWAYS make a contiguous copy to ensure independence from mmap
		// Make contiguous copy to ensure independence from mmap
		arr = mlx.Contiguous(arr)
		mlx.Eval(arr)
		mw.cache[name] = arr
		sf.Free() // Safe to free - arr is now an independent copy
		arrays = append(arrays, arr)
	}

	// Batch evaluate all tensors at once (much faster than one at a time)
	mlx.Eval(arrays...)

	// Now safe to free all native handles
	for _, sf := range nativeHandles {
		sf.Free()
	}

	return nil
@@ -107,18 +145,95 @@ func (mw *ManifestWeights) HasTensor(name string) bool {
}

// Quantization returns the model's quantization type from model_index.json.
// Returns empty string if not quantized or unknown.
// Returns empty string if not quantized.
// Falls back to detecting from tensor names and shapes if not in config.
func (mw *ManifestWeights) Quantization() string {
	if mw.manifest == nil {
		return ""
	}

	// Try to read from model_index.json first
	var index struct {
		Quantization string `json:"quantization"`
	}
	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
	if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
		return index.Quantization
	}

	// Fallback: detect from tensor names
	// Check if any tensors have _scale suffix (indicates quantization)
	hasScales := false
	hasQBias := false
	for name := range mw.tensors {
		if strings.HasSuffix(name, ".weight_scale") {
			hasScales = true
		}
		if strings.HasSuffix(name, ".weight_qbias") {
			hasQBias = true
		}
	}

	if !hasScales {
		// No scales = not quantized
		return ""
	}
	return index.Quantization

	// Has scales but no qbias = NVFP4 (or other non-affine mode)
	if !hasQBias {
		return "NVFP4"
	}

	// Has both scales and qbias = affine mode
	// Need to determine FP4 vs FP8 from tensor shapes
	// FP4: weight last dim is 1/8 of scales last dim * group_size
	// FP8: weight last dim is 1/4 of scales last dim * group_size
	//
	// For affine mode with group_size=32:
	// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
	// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
	// scales_dim = orig_dim / group_size
	// So: weight_dim / scales_dim = group_size / pack_factor
	// FP4: ratio = 32/8 = 4
	// FP8: ratio = 32/4 = 8

	// Find a weight/scale pair to check the ratio
	for name := range mw.tensors {
		if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
			continue
		}
		scaleName := name + "_scale"
		if _, ok := mw.tensors[scaleName]; !ok {
			continue
		}

		// Load both tensors to check shapes
		weightLayer := mw.tensors[name]
		scaleLayer := mw.tensors[scaleName]

		// Get shapes from manifest layer metadata if available
		// For now, default to FP4 since it's more common
		// The actual shape check would require loading the tensor

		// Simple heuristic: check if scale tensor is ~4x smaller than weight
		// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
		// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
		// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
		// So scale size should be ~weight_size * 4 / 32 = weight_size / 8

		// Rough size heuristic (assuming float16 scales)
		// FP4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
		// FP8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
		ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
		if ratio < 12 {
			// Closer to 8 = FP4
			return "FP4"
		}
		// Closer to 16 = FP8
		return "FP8"
	}

	// Default to FP4 for affine mode (most common)
	return "FP4"
}

// ReleaseAll frees all native handles and clears the tensor cache.

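To make the size-ratio heuristic above concrete, a small worked example with assumed dimensions (a 4096x4096 layer, group_size=32, float16 scales); the numbers are illustrative only and not taken from any real manifest.

	4096*4096 = 16,777,216 elements.
	FP4 affine: 8 values packed per uint32 -> weight bytes = 16,777,216/8*4 = 8,388,608
	            scales: one fp16 per group of 32 -> 16,777,216/32*2 = 1,048,576 bytes
	            ratio = 8,388,608 / 1,048,576 = 8   (< 12, reported as "FP4")
	FP8 affine: 4 values packed per uint32 -> weight bytes = 16,777,216/4*4 = 16,777,216
	            scales unchanged -> ratio = 16      (>= 12, reported as "FP8")
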
@@ -1,797 +1,144 @@
//go:build mlx

package kvcache

// import (
// 	"errors"
// 	"fmt"
// 	"log/slog"
// 	"math"
// 	"slices"

// 	"github.com/ollama/ollama/ml"
// 	"github.com/ollama/ollama/model/input"
// )

// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// // Causal cache stores K and V tensors according to their position in the
// // sequence. Returns the history and a mask for attending to past tokens
// //
// // The tensors are of shape embed dim, kv heads, batch size
// // The mask is of shape history size, batch size
// type Causal struct {
// 	DType ml.DType

// 	// swaWindowSize is the number of tokens that will be included in the mask
// 	// during attention operations. swaMemorySize is the number of tokens that
// 	// will be retained in memory for partial prefix caching. Set to math.MaxInt32
// 	// for unlimited or if sliding window attention is not being used.
// 	swaWindowSize int32
// 	swaMemorySize int32

// 	chunkSize int32

// 	opts CausalOptions

// 	// maxBatch is the largest batch that we might receive
// 	maxBatch int

// 	// config controls mostly backend-specific optimizations
// 	config *ml.CacheConfig

// 	// ** current forward pass **

// 	// size of the current batch
// 	curBatchSize int

// 	// locations for data storage for this batch
// 	curLoc ml.Tensor

// 	// mask of the cache as used by this batch
// 	curMask ml.Tensor

// 	// the active layer for Get and Put
// 	curLayer int

// 	// locations in the cache that are needed for this batch
// 	curCellRange cellRange

// 	// curSequences is the sequences corresponding to this pass's entries in the cache
// 	curSequences []int

// 	// curPositions is the positions corresponding to this pass's entries in the cache
// 	curPositions []int32

// 	// ** cache metadata **

// 	// for each possible location in the cache, stores the position and set of sequences
// 	// that reference the data there
// 	cells []cacheCell

// 	// maps from sequence to the range of locations where it is stored in the cache
// 	cellRanges map[int]cellRange

// 	// ** cache data storage **

// 	shiftFn shiftFn
// 	backend ml.Backend
// 	ctxs map[int]ml.Context
// 	keys, values map[int]ml.Tensor

// 	kHeadDims, vHeadDims, numKVHeads map[int]int
// }

// type cacheCell struct {
// 	pos int32
// 	sequences []int
// }

// type cellRange struct {
// 	min int
// 	max int
// }

// func NewCausalCache(shift shiftFn) *Causal {
// 	return &Causal{
// 		shiftFn: shift,
// 		ctxs: make(map[int]ml.Context),
// 		keys: make(map[int]ml.Tensor),
// 		values: make(map[int]ml.Tensor),
// 		kHeadDims: make(map[int]int),
// 		vHeadDims: make(map[int]int),
// 		numKVHeads: make(map[int]int),
// 	}
// }

// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// swaWindowSize: windowSize,
|
||||
// swaMemorySize: memorySize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
// return &Causal{
|
||||
// chunkSize: chunkSize,
|
||||
// shiftFn: shift,
|
||||
// ctxs: make(map[int]ml.Context),
|
||||
// keys: make(map[int]ml.Tensor),
|
||||
// values: make(map[int]ml.Tensor),
|
||||
// kHeadDims: make(map[int]int),
|
||||
// vHeadDims: make(map[int]int),
|
||||
// numKVHeads: make(map[int]int),
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// if c.config == nil {
|
||||
// var config ml.CacheConfig
|
||||
// if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
// config = cc.CacheConfig()
|
||||
// }
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// if c.config.CachePadding == 0 {
|
||||
// c.config.CachePadding = 1
|
||||
// }
|
||||
|
||||
// if c.config.MaskBatchPadding == 0 {
|
||||
// c.config.MaskBatchPadding = 1
|
||||
// }
|
||||
|
||||
// // TODO what types do we handle here?
|
||||
// // if c.config.MaskDType == ml.DTypeOther {
|
||||
// // c.config.MaskDType = ml.DTypeFloat32
|
||||
// // }
|
||||
|
||||
// if c.swaWindowSize == 0 {
|
||||
// c.swaWindowSize = math.MaxInt32
|
||||
// }
|
||||
// if c.swaMemorySize == 0 {
|
||||
// c.swaMemorySize = c.swaWindowSize
|
||||
// }
|
||||
// // We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// // sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// // causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// // because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// // only have a single sequence.
|
||||
// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
// }
|
||||
// if int(c.swaMemorySize) >= capacity {
|
||||
// c.swaMemorySize = math.MaxInt32
|
||||
// }
|
||||
|
||||
// if c.swaMemorySize < c.swaWindowSize {
|
||||
// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
// }
|
||||
|
||||
// var cacheSize int
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// cacheSize = maxSequences * capacity
|
||||
// } else {
|
||||
// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
// }
|
||||
// cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
// c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
// c.DType = dtype
|
||||
// c.cellRanges = make(map[int]cellRange)
|
||||
// c.backend = backend
|
||||
// c.maxBatch = maxBatch
|
||||
// }
|
||||
|
||||
// func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
// if c.config != nil {
|
||||
// panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
// }
|
||||
|
||||
// c.config = &config
|
||||
// }
|
||||
|
||||
// func (c *Causal) Close() {
|
||||
// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
|
||||
// for _, ctx := range c.ctxs {
|
||||
// ctx.Close()
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
|
||||
// // panic("XXX Causal.StartForward")
|
||||
// c.curBatchSize = len(batch.Positions)
|
||||
// c.curSequences = batch.Sequences
|
||||
// c.curPositions = batch.Positions
|
||||
// c.opts.Except = nil
|
||||
|
||||
// var locs []int32
|
||||
// if !reserve {
|
||||
// c.updateSlidingWindow()
|
||||
|
||||
// var err error
|
||||
// locs, err = c.findLocs()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
|
||||
|
||||
// for i, pos := range batch.Positions {
|
||||
// seq := batch.Sequences[i]
|
||||
// loc := int(locs[i])
|
||||
|
||||
// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// seqRange = newRange()
|
||||
// }
|
||||
|
||||
// seqRange.min = min(seqRange.min, loc)
|
||||
// c.curCellRange.min = min(c.curCellRange.min, loc)
|
||||
|
||||
// seqRange.max = max(seqRange.max, loc)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, loc)
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
// }
|
||||
// } else {
|
||||
// // If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// // to the worst case.
|
||||
// locs = make([]int32, c.curBatchSize)
|
||||
// for i := range locs {
|
||||
// locs[i] = int32(i)
|
||||
// }
|
||||
// c.curCellRange.min = 0
|
||||
// c.curCellRange.max = len(c.cells) - 1
|
||||
// }
|
||||
|
||||
// // XXX Building up the locs for what's already processed (if any)
|
||||
// dummyLocs := []int{}
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// } else {
|
||||
// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
|
||||
// dummyLocs = append(dummyLocs, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
|
||||
|
||||
// slog.Info("XXX Causal.StartForward", "locs", locs)
|
||||
// c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func newRange() cellRange {
|
||||
// return cellRange{
|
||||
// min: math.MaxInt,
|
||||
// max: 0,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Returns a slice of locations where each token in the batch should be stored
|
||||
// func (c *Causal) findLocs() ([]int32, error) {
|
||||
// loc := make([]int32, 0, c.curBatchSize)
|
||||
|
||||
// for i := range c.cells {
|
||||
// if len(c.cells[i].sequences) == 0 {
|
||||
// loc = append(loc, int32(i))
|
||||
// if len(loc) >= c.curBatchSize {
|
||||
// return loc, nil
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
// }
|
||||
|
||||
// func (c *Causal) updateSlidingWindow() {
|
||||
// c.curCellRange = newRange()
|
||||
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// for _, seq := range c.curSequences {
|
||||
// if seqRange, ok := c.cellRanges[seq]; ok {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||
// }
|
||||
// }
|
||||
|
||||
// return
|
||||
// }
|
||||
|
||||
// type lowestPosition struct {
|
||||
// pos int32
|
||||
// curBatch bool
|
||||
// }
|
||||
|
||||
// // create a map of unique sequences to the lowest position in that sequence
|
||||
// lowestPos := make(map[int]lowestPosition)
|
||||
// for i := range c.curPositions {
|
||||
// seq := c.curSequences[i]
|
||||
|
||||
// lowest, ok := lowestPos[seq]
|
||||
// if !ok {
|
||||
// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
// } else if c.curPositions[i] < lowest.pos {
|
||||
// lowest.pos = c.curPositions[i]
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowest
|
||||
// }
|
||||
|
||||
// // for any sequences are not part of this batch, clean up any tokens
|
||||
// // that are no longer needed after the processing of the previous
|
||||
// // batch
|
||||
// for seq, seqRange := range c.cellRanges {
|
||||
// if _, ok := lowestPos[seq]; !ok {
|
||||
// var last int32
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
// }
|
||||
// }
|
||||
|
||||
// // delete any entries that are beyond the window of the oldest position in the sequence
|
||||
// for seq, lowest := range lowestPos {
|
||||
// oldRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// newRange := newRange()
|
||||
|
||||
// for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// newRange.min = min(newRange.min, i)
|
||||
// newRange.max = max(newRange.max, i)
|
||||
// }
|
||||
// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
// c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
// c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = newRange
|
||||
// }
|
||||
// }
|
||||
|
||||
// func roundDown(length, pad int) int {
|
||||
// return (length / pad) * pad
|
||||
// }
|
||||
|
||||
// func roundUp(length, pad int) int {
|
||||
// return ((length + pad - 1) / pad) * pad
|
||||
// }
|
||||
|
||||
// // Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// // token in the history should apply. This is based on both the sequence and causality (the
|
||||
// // position of the history is not ahead of the token in the batch).
|
||||
// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// // Align and pad the two dimensions as required by the backend
|
||||
// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
// length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
// mask := make([]float32, batchSize*length)
|
||||
|
||||
// for i := range c.curBatchSize {
|
||||
// enabled := !slices.Contains(c.opts.Except, i)
|
||||
// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
// (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// // has already been masked out because the sequence doesn't match.
|
||||
// for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
// mask[i] = float32(math.Inf(-1))
|
||||
// }
|
||||
|
||||
// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
|
||||
|
||||
// // if c.config.MaskDType != ml.DTypeFloat32 {
|
||||
// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
// // }
|
||||
|
||||
// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
|
||||
|
||||
// return maskTensor
|
||||
// }
|
||||
|
||||
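// --- Illustrative sketch, not part of this diff: the masking rule that the
// commented-out buildMask above describes. A cached cell stays unmasked (0)
// only when it belongs to the same sequence as the query token, is not ahead
// of the query position, and has not slid out of the attention window.
// Names and types below are hypothetical and chosen only for this example.

package main

import (
	"fmt"
	"math"
)

type cell struct {
	pos int32
	seq int
}

// causalMask returns a len(positions) x len(cells) row-major mask:
// 0 where the batch token may attend to the cached cell, -Inf where it may not.
func causalMask(positions []int32, sequences []int, cells []cell, window int32) []float32 {
	mask := make([]float32, len(positions)*len(cells))
	for i := range positions {
		for j, c := range cells {
			if c.seq != sequences[i] || // different sequence
				c.pos > positions[i] || // in the future (causality)
				c.pos < positions[i]-window { // slid out of the window
				mask[i*len(cells)+j] = float32(math.Inf(-1))
			}
		}
	}
	return mask
}

func main() {
	cells := []cell{{0, 0}, {1, 0}, {2, 0}, {3, 0}}
	// query positions 2 and 3 of sequence 0 with a window of 2
	fmt.Println(causalMask([]int32{2, 3}, []int{0, 0}, cells, 2))
}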
// func (c *Causal) SetLayer(layer int) {
|
||||
// c.curLayer = layer
|
||||
// }
|
||||
|
||||
// type CausalOptions struct {
|
||||
// // Enabled controls whether the causal mask is generated for a particular index in a batch
|
||||
// Except []int
|
||||
// }
|
||||
|
||||
// // SetCausal disables causal mask generation for a particular range of indicies in
|
||||
// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||
// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
// if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
// c.opts = opts
|
||||
// if ctx != nil {
|
||||
// c.curMask = c.buildMask(ctx)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
// key := c.keys[c.curLayer]
|
||||
// value := c.values[c.curLayer]
|
||||
|
||||
// kHeadDim := c.kHeadDims[c.curLayer]
|
||||
// vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// numKVHeads := c.numKVHeads[c.curLayer]
|
||||
// // rowSize := numKVHeads * c.curBatchSize
|
||||
// // cachedSize := c.curMask.Dim(1)
|
||||
// cachedSize := c.curLoc.Dim(0)
|
||||
// // kCellSize := kHeadDim * numKVHeads
|
||||
// // vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// slog.Info("XXX Causal.Get full cache", "key", key)
|
||||
// slog.Info("XXX Causal.Get full cache", "value", value)
|
||||
// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
|
||||
// slog.Info("XXX Causal.Get", "curMask", c.curMask)
|
||||
// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
|
||||
// // panic("XXX")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // panic("full cache value")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
|
||||
// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not converted
|
||||
// // vHeadDim := value.Dim(1)
|
||||
// // elemSize := value.Stride(2)
|
||||
|
||||
// // value = value.AsStrided(ctx,
|
||||
// // []int{numKVHeads, vHeadDim, cachedSize},
|
||||
// // []int{value.Stride(0), value.Stride(1)},
|
||||
// // elemSize*c.curCellRange.min,
|
||||
// // )
|
||||
// // } else {
|
||||
// // vHeadDim := c.vHeadDims[c.curLayer]
|
||||
// // rowSize := value.Stride(2)
|
||||
// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
|
||||
// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
|
||||
// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
|
||||
|
||||
// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
|
||||
// // panic("XXX")
|
||||
|
||||
// // }
|
||||
|
||||
// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
|
||||
// // // the 1 becomes trailing and messes up later operations
|
||||
// // // This isn't the right solution, but works around it...
|
||||
// // if c.curMask.Dim(1) == 1 {
|
||||
// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, key.ToString())
|
||||
// // fmt.Fprintln(os.Stderr, value.ToString())
|
||||
// // panic("XXX")
|
||||
// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
|
||||
|
||||
// return key, value, c.curMask
|
||||
// }
|
||||
|
||||
// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// kHeadDim := key.Dim(3)
|
||||
// vHeadDim := value.Dim(3)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// batchSize := key.Dim(2)
|
||||
// kCellSize := kHeadDim * numKVHeads
|
||||
// vCellSize := vHeadDim * numKVHeads
|
||||
|
||||
// // slog.Info("XXX Causal.Put", "key", key, "value", value)
|
||||
// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
|
||||
// // panic("XXX")
|
||||
|
||||
// if c.curBatchSize != batchSize {
|
||||
// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
// }
|
||||
|
||||
// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
|
||||
// if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
|
||||
// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
|
||||
// }
|
||||
|
||||
// if _, ok := c.keys[c.curLayer]; !ok {
|
||||
// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
|
||||
|
||||
// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
|
||||
// c.kHeadDims[c.curLayer] = kHeadDim
|
||||
// c.vHeadDims[c.curLayer] = vHeadDim
|
||||
// c.numKVHeads[c.curLayer] = numKVHeads
|
||||
// }
|
||||
|
||||
// if _, ok := c.values[c.curLayer]; !ok {
|
||||
// // if c.config.PermutedV {
|
||||
// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
|
||||
// // } else {
|
||||
// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
|
||||
// // }
|
||||
// }
|
||||
|
||||
// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
|
||||
// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
|
||||
// // panic("XXX")
|
||||
// // curLoc := 0 // TODO c.curLoc is now a tensor
|
||||
// // kSize := numKVHeads * kHeadDim
|
||||
// // vSize := numKVHeads * vHeadDim
|
||||
// // start := []int{int(curLoc), 0}
|
||||
// // kStop := []int{int(curLoc + batchSize), int(kSize)}
|
||||
// // vStop := []int{int(curLoc + batchSize), int(vSize)}
|
||||
// // strides := []int{1, 1}
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
|
||||
|
||||
// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
|
||||
|
||||
// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
|
||||
// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
|
||||
// // fmt.Fprintln(os.Stderr, keyCache.ToString())
|
||||
// // panic("input value")
|
||||
|
||||
// // fmt.Fprintln(os.Stderr, t.ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// // if c.config.PermutedV {
|
||||
// // panic("permuted")
|
||||
// // // TODO not adjusted
|
||||
// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
// // value = value.Transpose(ctx, 2, 0, 1, 3)
|
||||
|
||||
// // valueCache := c.values[c.curLayer]
|
||||
// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
|
||||
// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
|
||||
// // } else {
|
||||
// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
|
||||
// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
|
||||
|
||||
// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
|
||||
// // }
|
||||
// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
|
||||
// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
|
||||
// // panic("XXX")
|
||||
|
||||
// }
|
||||
|
||||
// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
|
||||
// if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||
// }
|
||||
|
||||
// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||
// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// c.cellRanges[dstSeq] = seqRange
|
||||
// }
|
||||
|
||||
// func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
// if c.swaMemorySize == math.MaxInt32 {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// seqRange, ok := c.cellRanges[seq]
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// // for sliding window, check that the window of the new sequence is contained in
|
||||
// // the window of what we are storing
|
||||
// var first int32 = math.MaxInt32
|
||||
// var last int32 = -1
|
||||
// for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// first = min(first, c.cells[i].pos)
|
||||
// last = max(last, c.cells[i].pos)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if last == -1 {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
// return posWindowStart >= first && pos <= last+1
|
||||
// }
|
||||
|
||||
// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
// if c.shiftFn == nil {
|
||||
// return ErrNotSupported
|
||||
// }
|
||||
|
||||
// seqRange := c.cellRanges[seq]
|
||||
|
||||
// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
// size := min(seqRange.max-start+1, c.maxBatch)
|
||||
// offsets := make([]int32, size)
|
||||
|
||||
// var batchFirst, batchLast int
|
||||
|
||||
// batchFirst = -1
|
||||
// for i := range offsets {
|
||||
// cell := c.cells[start+i]
|
||||
|
||||
// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
// offsets[i] = offset
|
||||
// if batchFirst < 0 {
|
||||
// batchFirst = i
|
||||
// }
|
||||
// batchLast = i
|
||||
// }
|
||||
// }
|
||||
|
||||
// if batchFirst < 0 {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// offsets = offsets[batchFirst : batchLast+1]
|
||||
|
||||
// slog.Info("XXX Causal.shift creating new temporary context")
|
||||
// ctx := c.backend.NewContext()
|
||||
// kShift := ctx.Input().FromInts(offsets, len(offsets))
|
||||
|
||||
// for i, key := range c.keys {
|
||||
// if key == nil {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// kHeadDim := key.Dim(2)
|
||||
// numKVHeads := key.Dim(1)
|
||||
// rowSize := key.Stride(0)
|
||||
|
||||
// key = key.AsStrided(ctx,
|
||||
// []int{len(offsets), numKVHeads, kHeadDim},
|
||||
// []int{key.Stride(0), key.Stride(1)},
|
||||
// rowSize*(start+batchFirst),
|
||||
// )
|
||||
|
||||
// roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
// if err != nil {
|
||||
// ctx.Close()
|
||||
// return err
|
||||
// }
|
||||
|
||||
// ctx.Forward(roped.Copy(ctx, key))
|
||||
// }
|
||||
|
||||
// ctx.Compute()
|
||||
// ctx.Close()
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// // TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// // cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// // should return an error, which will trigger the runner to evaluate the full history and
|
||||
// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// // results in use after free, so we don't do it for now.
|
||||
|
||||
// var offset int32
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// offset = beginIndex - endIndex
|
||||
// }
|
||||
|
||||
// seqRange := newRange()
|
||||
|
||||
// for i := range c.cells {
|
||||
// if slices.Contains(c.cells[i].sequences, seq) {
|
||||
// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||
// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
// } else {
|
||||
// if c.cells[i].pos >= endIndex {
|
||||
// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
// }
|
||||
|
||||
// c.cells[i].pos += offset
|
||||
// }
|
||||
// if i < seqRange.min {
|
||||
// seqRange.min = i
|
||||
// }
|
||||
// if i > seqRange.max {
|
||||
// seqRange.max = i
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if seqRange == newRange() {
|
||||
// delete(c.cellRanges, seq)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// c.cellRanges[seq] = seqRange
|
||||
|
||||
// if endIndex != math.MaxInt32 {
|
||||
// err := c.shift(seq, endIndex+offset, offset)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
import (
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type Causal struct {
	DType ml.DType

	// locations where this batch's entries will be stored (used by Put)
	curLocPut ml.Tensor

	// locations of all cached entries to read back (used by Get)
	curLocGet ml.Tensor

	// the active layer for Get and Put
	curLayer int

	capacity int

	offset int

	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor

	// TODO is this needed per layer, or will it always be consistent?
	kHeadDims, vHeadDims, numKVHeads map[int]int
}

func NewCausalCache() *Causal {
	return &Causal{
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
		kHeadDims:  make(map[int]int),
		vHeadDims:  make(map[int]int),
		numKVHeads: make(map[int]int),
	}
}

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.DType = dtype
	c.capacity = capacity
	c.backend = backend
}

func (c *Causal) SetConfig(config ml.CacheConfig) {}

func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

func (c *Causal) Close() {
	// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	// New entries are appended at the end of the cache: this batch's tokens
	// occupy slots [c.offset, c.offset+len(batch.Positions)).
	locsPut := make([]int32, len(batch.Positions))
	for i := range batch.Positions {
		locsPut[i] = int32(c.offset + i)
	}
	c.offset += len(batch.Positions)

	// Reads cover everything cached so far, including this batch.
	locsGet := make([]int32, c.offset)
	for i := range c.offset {
		locsGet[i] = int32(i)
	}
	c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
	c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
	// slog.Info("XXX Causal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)

	return nil
}

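// --- Illustrative sketch, not part of this diff: how the put/get locations
// above evolve over a 4-token prefill followed by a single decode step.
// The helper name is hypothetical; it mirrors the StartForward bookkeeping.

package main

import "fmt"

// locations appends a batch at the end of the cache and reads back every
// entry stored so far, matching the StartForward logic above.
func locations(offset, batchLen int) (put, get []int32, newOffset int) {
	put = make([]int32, batchLen)
	for i := range put {
		put[i] = int32(offset + i)
	}
	newOffset = offset + batchLen
	get = make([]int32, newOffset)
	for i := range get {
		get[i] = int32(i)
	}
	return put, get, newOffset
}

func main() {
	put, get, off := locations(0, 4) // prefill of 4 tokens
	fmt.Println(put, get, off)       // [0 1 2 3] [0 1 2 3] 4

	put, get, off = locations(off, 1) // one decode step
	fmt.Println(put, get, off)        // [4] [0 1 2 3 4] 5
}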
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(3)
	vHeadDim := value.Dim(3)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)
	kCellSize := kHeadDim * numKVHeads
	vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)

	if _, ok := c.ctxs[c.curLayer]; !ok {
		// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		// slog.Info("XXX Causal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
		c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
		c.kHeadDims[c.curLayer] = kHeadDim
		c.vHeadDims[c.curLayer] = vHeadDim
		c.numKVHeads[c.curLayer] = numKVHeads
	}
	key = key.Reshape(ctx, batchSize, 1, kCellSize)

	// slog.Info("XXX Causal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
	// slog.Info("XXX Causal.Put ", "c.curLocPut", c.curLocPut)
	// slog.Info("XXX Causal.Put ", "key", key)
	ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
	value = value.Reshape(ctx, batchSize, 1, vCellSize)
	ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
}

func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := c.kHeadDims[c.curLayer]
	vHeadDim := c.vHeadDims[c.curLayer]
	numKVHeads := c.numKVHeads[c.curLayer]
	// rowSize := numKVHeads * c.curBatchSize
	// cachedSize := c.curMask.Dim(1)
	cachedSize := c.curLocGet.Dim(0)
	// kCellSize := kHeadDim * numKVHeads
	// vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX Causal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})

	key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
	value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
	return key, value, nil
}

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("not implemented")
}

func (c *Causal) CanResume(seq int, pos int32) bool {
	panic("not implemented")
}

func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	panic("not implemented")
}

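// --- Rough usage sketch, not part of this diff: how a runner might drive
// this minimal cache. The import path of the kvcache package, the x/ml dtype
// constant, and all sizes are assumptions for illustration; backend, ctx, and
// the k/v tensors are assumed to come from an MLX model's forward pass.

package example

import (
	"github.com/ollama/ollama/x/kvcache"
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// runStep shows the expected call order for a single forward pass.
func runStep(backend ml.Backend, ctx ml.Context, batch input.Batch, k, v ml.Tensor) error {
	cache := kvcache.NewCausalCache()
	defer cache.Close()

	// maxSequences and maxBatch are currently ignored by this implementation.
	cache.Init(backend, ml.DTypeF16, 1, 4096, 512)

	if err := cache.StartForward(ctx, batch, false); err != nil {
		return err
	}

	cache.SetLayer(0)
	cache.Put(ctx, k, v) // k, v shaped [1, numKVHeads, batchSize, headDim]

	keys, values, _ := cache.Get(ctx) // this minimal cache returns no mask
	_, _ = keys, values
	return nil
}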
@@ -1,973 +0,0 @@
|
||||
package kvcache
|
||||
|
||||
// import (
|
||||
// "fmt"
|
||||
// "math"
|
||||
// "slices"
|
||||
// "testing"
|
||||
|
||||
// "github.com/ollama/ollama/ml"
|
||||
// "github.com/ollama/ollama/model/input"
|
||||
// )
|
||||
|
||||
// type testCase struct {
|
||||
// name string
|
||||
// in []float32
|
||||
// inShape []int
|
||||
// seqs []int
|
||||
// pos []int32
|
||||
// expected []float32
|
||||
// expectedShape []int
|
||||
// expectedMask []float32
|
||||
// }
|
||||
|
||||
// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
||||
// t.Helper()
|
||||
// for _, permuted := range []bool{false, true} {
|
||||
// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
||||
// fn(t, &testBackend{permutedV: permuted})
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestStore(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// inShape: []int{2, 3, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
// expectedShape: []int{2, 3, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{115, 215, 125, 225, 135, 235},
|
||||
// inShape: []int{2, 3, 1},
|
||||
// seqs: []int{0},
|
||||
// pos: []int32{4},
|
||||
// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||
// expectedShape: []int{2, 3, 5},
|
||||
// expectedMask: []float32{0, 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWA(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 6, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWASeparateBatches(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWACache(1, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "First seq 0",
|
||||
// in: []float32{1, 2},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{1, 2},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 0",
|
||||
// in: []float32{3, 4},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 3},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x,
|
||||
// x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "First seq 1",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{0, 1},
|
||||
// expected: []float32{5, 6},
|
||||
// expectedShape: []int{1, 1, 2},
|
||||
// expectedMask: []float32{
|
||||
// 0, x,
|
||||
// 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Second seq 1",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{2, 3},
|
||||
// expected: []float32{6, 3, 4, 7, 8},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// x, x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Third seq 0",
|
||||
// in: []float32{9, 10},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{9, 10, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0,
|
||||
// 0, 0, x, x,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewSWAMemCache(1, 3, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, 0, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{4, 5},
|
||||
// expected: []float32{5, 2, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, 0, x,
|
||||
// 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestChunkedAttention(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewChunkedAttentionCache(2, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// testCache(
|
||||
// t, backend, cache,
|
||||
// []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6, 7},
|
||||
// inShape: []int{1, 1, 3},
|
||||
// seqs: []int{0, 0, 0},
|
||||
// pos: []int32{4, 5, 6},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||
// expectedShape: []int{1, 1, 7},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, 0, x, x,
|
||||
// x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "ThirdBatch",
|
||||
// in: []float32{8, 9},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{7, 8},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
// expectedShape: []int{1, 1, 9},
|
||||
// expectedMask: []float32{
|
||||
// x, x, x, x, x, x, 0, 0, x,
|
||||
// x, x, x, x, x, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// )
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestSequences(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// {
|
||||
// name: "SecondBatch",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{2, 2},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestRemove(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// return key.Add(ctx, shift), nil
|
||||
// })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// x := float32(math.Inf(-1))
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 1, 1},
|
||||
// pos: []int32{0, 1, 0, 1},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{
|
||||
// 0, x, x, x,
|
||||
// 0, 0, x, x,
|
||||
// x, x, 0, x,
|
||||
// x, x, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err := cache.Remove(0, 1, math.MaxInt32)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveEnd",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 1},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{1, 5, 3, 4, 6},
|
||||
// expectedShape: []int{1, 1, 5},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x,
|
||||
// x, x, 0, 0, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// err = cache.Remove(0, 0, 1)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "RemoveMiddle",
|
||||
// in: []float32{7, 8},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{0, 0},
|
||||
// pos: []int32{1, 2},
|
||||
// expected: []float32{7, 4, 3, 4, 6, 8},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{
|
||||
// 0, 0, x, x, x, x,
|
||||
// 0, 0, x, x, x, 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCopy(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// tests := []testCase{
|
||||
// {
|
||||
// name: "FirstBatch",
|
||||
// in: []float32{1, 2, 3, 4},
|
||||
// inShape: []int{1, 1, 4},
|
||||
// seqs: []int{0, 0, 0, 0},
|
||||
// pos: []int32{0, 1, 2, 3},
|
||||
// expected: []float32{1, 2, 3, 4},
|
||||
// expectedShape: []int{1, 1, 4},
|
||||
// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
|
||||
// cache.CopyPrefix(0, 1, 2)
|
||||
|
||||
// tests = []testCase{
|
||||
// {
|
||||
// name: "Copy",
|
||||
// in: []float32{5, 6},
|
||||
// inShape: []int{1, 1, 2},
|
||||
// seqs: []int{1, 1},
|
||||
// pos: []int32{3, 4},
|
||||
// expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
// expectedShape: []int{1, 1, 6},
|
||||
// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
// },
|
||||
// }
|
||||
|
||||
// testCache(t, backend, cache, tests)
|
||||
// })
|
||||
// }
|
||||
|
||||
// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||
// for _, test := range tests {
|
||||
// t.Run(test.name, func(t *testing.T) {
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats(test.in, test.inShape...)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// out, _, mask := cache.Get(context)
|
||||
|
||||
// context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
// if !slices.Equal(out.Floats(), test.expected) {
|
||||
// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||
// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||
// }
|
||||
|
||||
// if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func TestCanResume(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// cache := NewSWACache(windowSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4},
|
||||
// Sequences: []int{0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // with window size 4, nothing has slid out of the window yet
|
||||
// if !cache.CanResume(0, 0) {
|
||||
// t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 1) {
|
||||
// t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 2) {
|
||||
// t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 3) {
|
||||
// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 4) {
|
||||
// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
// }
|
||||
|
||||
// // shift window by adding position 5
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{5},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// func TestCanResumeSWAMem(t *testing.T) {
|
||||
// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||
// windowSize := int32(4)
|
||||
// memSize := int32(5)
|
||||
// cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||
// defer cache.Close()
|
||||
|
||||
// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
// context := backend.NewContext()
|
||||
// defer context.Close()
|
||||
|
||||
// err := cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // shift window by adding position 7
|
||||
// err = cache.StartForward(context, input.Batch{
|
||||
// Positions: []int32{7},
|
||||
// Sequences: []int{0},
|
||||
// }, false)
|
||||
// if err != nil {
|
||||
// t.Fatalf("StartForward failed: %v", err)
|
||||
// }
|
||||
|
||||
// cache.SetLayer(0)
|
||||
// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
|
||||
// cache.Put(context, tensor, tensor)
|
||||
|
||||
// // only the latest position has overlapping windows
|
||||
// if cache.CanResume(0, 0) {
|
||||
// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 1) {
|
||||
// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 2) {
|
||||
// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 3) {
|
||||
// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 4) {
|
||||
// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
// }
|
||||
// if cache.CanResume(0, 5) {
|
||||
// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 6) {
|
||||
// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||
// }
|
||||
// if !cache.CanResume(0, 7) {
|
||||
// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// type testBackend struct {
|
||||
// ml.Backend
|
||||
// permutedV bool
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContext() ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
// return &testContext{}
|
||||
// }
|
||||
|
||||
// func (b *testBackend) CacheConfig() ml.CacheConfig {
|
||||
// return ml.CacheConfig{PermutedV: b.permutedV}
|
||||
// }
|
||||
|
||||
// type testContext struct {
|
||||
// ml.Context
|
||||
// }
|
||||
|
||||
// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
// total := 0
|
||||
|
||||
// if len(shape) > 0 {
|
||||
// total = 1
|
||||
// for _, s := range shape {
|
||||
// total *= s
|
||||
// }
|
||||
// }
|
||||
|
||||
// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
// }
|
||||
|
||||
// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
//	return c.Empty(dtype, shape...)
// }

// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
//	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)

//	copy(t.data, s)

//	return t
// }

// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
//	f := make([]float32, len(s))
//	for i := range f {
//		f[i] = float32(s[i])
//	}

//	out := c.FromFloats(f, shape...)
//	out.(*testTensor).dtype = ml.DTypeI32

//	return out
// }

// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
//	s := make([]float32, 0, int((stop-start)/step))
//	for i := start; i < stop; i += step {
//		s = append(s, i)
//	}

//	out := c.FromFloats(s, len(s))
//	out.(*testTensor).dtype = dtype
//	return out
// }

// func (c *testContext) Input() ml.Context { return c }
// func (c *testContext) Layer(int) ml.Context { return c }

// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

// func (c *testContext) Compute(...ml.Tensor) {}

// func (c *testContext) Reserve() {}

// func (c *testContext) MaxGraphNodes() int {
//	return 10
// }

// func (c *testContext) Close() {}

// type testTensor struct {
//	ml.Tensor

//	dtype       ml.DType
//	elementSize int
//	data        []float32
//	shape       []int
// }

// func (t *testTensor) Dim(n int) int {
//	return t.shape[n]
// }

// func (t *testTensor) Stride(n int) int {
//	stride := t.elementSize
//	for i := range n {
//		stride *= t.shape[i]
//	}

//	return stride
// }

// func (t *testTensor) Shape() []int {
//	return t.shape
// }

// func (t *testTensor) DType() ml.DType {
//	return t.dtype
// }

// func (t *testTensor) Floats() []float32 {
//	out := make([]float32, len(t.data))
//	copy(out, t.data)
//	return out
// }

// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
//	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
//	for i := range out.data {
//		out.data[i] = -t.data[i]
//	}
//	return out
// }

// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
//	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

//	for i := range out.data {
//		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
//	}

//	return out
// }

// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
//	return &testTensor{
//		dtype:       t.dtype,
//		elementSize: t.elementSize,
//		data:        t.data,
//		shape:       shape,
//	}
// }

// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
//	offset /= t.elementSize

//	var s []int

//	switch len(shape) {
//	case 1:
//		s = []int{shape[0]}
//	case 3:
//		s = []int{shape[0], shape[2]}
//	case 5:
//		s = []int{shape[0], shape[2], shape[4]}
//	default:
//		panic("unsupported number of dimensions")
//	}

//	context := &testContext{}

//	view := context.Empty(t.dtype, s...).(*testTensor)
//	view.data = t.data[offset : offset+len(view.data)]

//	return view
// }

// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
//	if len(t.shape) > 4 || len(order) > 4 {
//		panic("permute only supports up to 4 dimensions")
//	}

//	if len(order) != len(t.shape) && len(order) != 4 {
//		panic("invalid number of dimensions for permute")
//	}

//	// ggml_permute expects 4 axes, so fill in any missing dimensions.
//	orderFull := append(make([]int, 0, 4), order...)
//	for len(orderFull) < 4 {
//		orderFull = append(orderFull, len(orderFull))
//	}

//	seen := [4]bool{}

//	shape4 := [4]int{1, 1, 1, 1}
//	for i := 0; i < len(t.shape) && i < 4; i++ {
//		shape4[i] = t.shape[i]
//	}

//	newShape4 := [4]int{1, 1, 1, 1}
//	for axis := range 4 {
//		dst := orderFull[axis]
//		if dst < 0 || dst >= 4 {
//			panic("invalid axis for permute")
//		}
//		if seen[dst] {
//			panic("duplicate axis for permute")
//		}
//		seen[dst] = true
//		newShape4[dst] = shape4[axis]
//	}

//	total := len(t.data)
//	newData := make([]float32, total)

//	if total > 0 {
//		oldDims := shape4
//		newDims := newShape4

//		oldStride := [4]int{1, 1, 1, 1}
//		newStride := [4]int{1, 1, 1, 1}
//		for i := 1; i < 4; i++ {
//			oldStride[i] = oldStride[i-1] * oldDims[i-1]
//			newStride[i] = newStride[i-1] * newDims[i-1]
//		}

//		var coords [4]int
//		var newCoords [4]int

//		for idx := range total {
//			remainder := idx
//			for axis := range 4 {
//				dim := oldDims[axis]
//				if dim == 0 {
//					coords[axis] = 0
//					continue
//				}
//				coords[axis] = remainder % dim
//				remainder /= dim
//			}

//			for axis := range 4 {
//				newCoords[orderFull[axis]] = coords[axis]
//			}

//			newIndex := 0
//			for axis := range 4 {
//				if newDims[axis] == 0 {
//					continue
//				}
//				newIndex += newCoords[axis] * newStride[axis]
//			}

//			newData[newIndex] = t.data[idx]
//		}
//	}

//	numDims := 4
//	for numDims > 1 && newShape4[numDims-1] <= 1 {
//		numDims--
//	}

//	newShape := make([]int, numDims)
//	copy(newShape, newShape4[:numDims])

//	return &testTensor{
//		dtype:       t.dtype,
//		elementSize: t.elementSize,
//		data:        newData,
//		shape:       newShape,
//	}
// }

// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
//	dst := t
//	srcTensor := src.(*testTensor)
//	idxTensor := idxs.(*testTensor)

//	shapeTo4D := func(shape []int) [4]int {
//		out := [4]int{1, 1, 1, 1}
//		for i := 0; i < len(shape) && i < 4; i++ {
//			out[i] = shape[i]
//		}
//		return out
//	}

//	computeStrides := func(shape [4]int) [4]int {
//		out := [4]int{1, 1, 1, 1}
//		for i := 1; i < 4; i++ {
//			out[i] = out[i-1] * shape[i-1]
//		}
//		return out
//	}

//	dstShape4D := shapeTo4D(dst.shape)
//	srcShape4D := shapeTo4D(srcTensor.shape)
//	idxShape4D := shapeTo4D(idxTensor.shape)

//	if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
//		panic("SetRows requires matching tensor shapes")
//	}

//	if srcShape4D[1] != idxShape4D[0] {
//		panic("SetRows rows/index mismatch")
//	}

//	if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
//		panic("SetRows cannot broadcast indices")
//	}

//	if idxShape4D[3] != 1 {
//		panic("SetRows expects 1D or 2D index tensors")
//	}

//	dstStride := computeStrides(dstShape4D)
//	srcStride := computeStrides(srcShape4D)
//	idxStride := computeStrides(idxShape4D)

//	numColumns := srcShape4D[0]
//	numRows := srcShape4D[1]

//	for dim3Index := range dstShape4D[3] {
//		for dim2Index := range dstShape4D[2] {
//			idxDim2 := 0
//			idxDim3 := 0
//			if idxShape4D[1] > 0 {
//				idxDim2 = dim2Index % idxShape4D[1]
//			}
//			if idxShape4D[2] > 0 {
//				idxDim3 = dim3Index % idxShape4D[2]
//			}

//			idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
//			srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
//			dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]

//			for row := range numRows {
//				idx := int(idxTensor.data[idxBase+row*idxStride[0]])
//				if idx < 0 || idx >= dstShape4D[1] {
//					panic("SetRows index out of range")
//				}

//				srcOffset := srcBase + row*srcStride[1]
//				dstOffset := dstBase + idx*dstStride[1]

//				copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
//			}
//		}
//	}

//	return dst
// }

// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
//	copy(t2.(*testTensor).data, t.data)
//	return nil
// }
x/kvcache/mlx.go
@@ -1,144 +0,0 @@
//go:build mlx

package kvcache

import (
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/model/input"
)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
type MLXCausal struct {
	DType ml.DType

	// locations for data storage for this batch
	curLocPut ml.Tensor

	// locations for data storage for this batch
	curLocGet ml.Tensor

	// the active layer for Get and Put
	curLayer int

	capacity int

	offset int

	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor

	// TODO is this needed per layer, or will it always be consistent?
	kHeadDims, vHeadDims, numKVHeads map[int]int
}

func NewMLXCausalCache() *MLXCausal {
	return &MLXCausal{
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
		kHeadDims:  make(map[int]int),
		vHeadDims:  make(map[int]int),
		numKVHeads: make(map[int]int),
	}
}

func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.DType = dtype
	c.capacity = capacity
	c.backend = backend
}

func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}

func (c *MLXCausal) SetLayer(layer int) {
	c.curLayer = layer
}

func (c *MLXCausal) Close() {
	// slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	locsPut := make([]int32, len(batch.Positions))
	for i := c.offset; i < len(batch.Positions); i++ {
		locsPut[i-c.offset] = int32(i)
	}
	c.offset += len(batch.Positions)
	locsGet := make([]int32, c.offset)
	for i := range c.offset {
		locsGet[i] = int32(i)
	}
	c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
	c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
	// slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)

	return nil
}
func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(3)
	vHeadDim := value.Dim(3)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)
	kCellSize := kHeadDim * numKVHeads
	vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)

	if _, ok := c.ctxs[c.curLayer]; !ok {
		// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
		c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		// slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
		c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
		c.kHeadDims[c.curLayer] = kHeadDim
		c.vHeadDims[c.curLayer] = vHeadDim
		c.numKVHeads[c.curLayer] = numKVHeads
	}
	key = key.Reshape(ctx, batchSize, 1, kCellSize)

	// slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
	// slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
	// slog.Info("XXX MLXCausal.Put ", "key", key)
	ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
	value = value.Reshape(ctx, batchSize, 1, vCellSize)
	ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))

}

func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := c.kHeadDims[c.curLayer]
	vHeadDim := c.vHeadDims[c.curLayer]
	numKVHeads := c.numKVHeads[c.curLayer]
	// rowSize := numKVHeads * c.curBatchSize
	// cachedSize := c.curMask.Dim(1)
	cachedSize := c.curLocGet.Dim(0)
	// kCellSize := kHeadDim * numKVHeads
	// vCellSize := vHeadDim * numKVHeads
	// slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})

	key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
	value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
	return key, value, nil
}

func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("not implemented")
}

func (c *MLXCausal) CanResume(seq int, pos int32) bool {
	panic("not implemented")
}

func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
	panic("not implemented")
}
x/mlxrunner/imagegen.go
@@ -0,0 +1,134 @@
//go:build mlx

package mlxrunner

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"sync"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/flux2"
	"github.com/ollama/ollama/x/imagegen/models/zimage"
)

// ImageModel is the interface for image generation models.
type ImageModel interface {
	GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
}

var imageGenMu sync.Mutex

// loadImageModel loads an image generation model.
func (s *server) loadImageModel() error {
	// Check memory requirements before loading
	var requiredMemory uint64
	if manifest, err := imagegen.LoadManifest(s.modelName); err == nil {
		requiredMemory = uint64(manifest.TotalTensorSize())
	}
	availableMemory := mlx.GetMemoryLimit()
	if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
		return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
			requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
	}

	// Detect model type and load appropriate model
	modelType := imagegen.DetectModelType(s.modelName)
	slog.Info("detected image model type", "type", modelType)

	var model ImageModel
	switch modelType {
	case "Flux2KleinPipeline":
		m := &flux2.Model{}
		if err := m.Load(s.modelName); err != nil {
			return fmt.Errorf("failed to load flux2 model: %w", err)
		}
		model = m
	default:
		// Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
		m := &zimage.Model{}
		if err := m.Load(s.modelName); err != nil {
			return fmt.Errorf("failed to load zimage model: %w", err)
		}
		model = m
	}

	s.imageModel = model
	return nil
}

// handleImageCompletion handles image generation requests.
func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
	// Serialize generation requests - MLX model may not handle concurrent generation
	imageGenMu.Lock()
	defer imageGenMu.Unlock()

	// Set seed if not provided
	if req.Seed <= 0 {
		req.Seed = time.Now().UnixNano()
	}

	// Set up streaming response
	w.Header().Set("Content-Type", "application/x-ndjson")
	w.Header().Set("Transfer-Encoding", "chunked")
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming not supported", http.StatusInternalServerError)
		return
	}

	ctx := r.Context()
	enc := json.NewEncoder(w)

	// Progress callback streams step updates
	progress := func(step, total int) {
		resp := Response{Step: step, Total: total}
		enc.Encode(resp)
		w.Write([]byte("\n"))
		flusher.Flush()
	}

	// Generate image
	img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
	if err != nil {
		// Don't send error for cancellation
		if ctx.Err() != nil {
			return
		}
		resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
		data, _ := json.Marshal(resp)
		w.Write(data)
		w.Write([]byte("\n"))
		return
	}

	// Encode image as base64 PNG
	imageData, err := imagegen.EncodeImageBase64(img)
	if err != nil {
		resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
		data, _ := json.Marshal(resp)
		w.Write(data)
		w.Write([]byte("\n"))
		return
	}

	// Free the generated image array and clean up MLX state
	img.Free()
	mlx.ClearCache()
	mlx.MetalResetPeakMemory()

	// Send final response with image data
	resp := Response{
		Image: imageData,
		Done:  true,
	}
	data, _ := json.Marshal(resp)
	w.Write(data)
	w.Write([]byte("\n"))
	flusher.Flush()
}
x/mlxrunner/llm.go
@@ -0,0 +1,420 @@
//go:build mlx

package mlxrunner

import (
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// TextModel is the interface for LLM text generation models.
type TextModel interface {
	Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
	NewCache(maxSeqLen int32) []cache.Cache
	Tokenizer() *tokenizer.Tokenizer
	VocabSize() int32
	MaxContextLength() int32
	NumLayers() int
}

// llmState holds the state for LLM generation
type llmState struct {
	model TextModel
}

var llmMu sync.Mutex

// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream

// withStream runs fn with the generation stream as default
func withStream(fn func()) {
	// Lazy initialization of generationStream
	if generationStream == nil {
		generationStream = mlx.NewStream()
	}
	orig := mlx.GetDefaultStream()
	mlx.SetDefaultStream(generationStream)
	fn()
	mlx.SetDefaultStream(orig)
}

// Decoder wraps model + cache for autoregressive generation.
// This matches the pattern from cmd/engine/generate.go
type Decoder struct {
	model         TextModel
	caches        []cache.Cache
	vocabSize     int32
	temp          float32
	token         *mlx.Array   // Current token (kept across iterations)
	oldCacheState []*mlx.Array // Preallocated slice for old cache state
}

func NewDecoder(m TextModel, temp float32) *Decoder {
	caches := m.NewCache(0)
	return &Decoder{
		model:         m,
		caches:        caches,
		vocabSize:     m.VocabSize(),
		temp:          temp,
		oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
	}
}

func (d *Decoder) prefill(inputIDs []int32) int {
	processed := 0

	// Track old cache state to free after each chunk
	var oldCacheState []*mlx.Array

	// Process all-but-1 tokens in chunks, eval cache state for memory management
	for len(inputIDs) > 1 {
		chunkSize := min(2048, len(inputIDs)-1)
		if chunkSize <= 0 {
			break
		}
		chunk := inputIDs[:chunkSize]

		// Save old cache state before forward
		oldCacheState = oldCacheState[:0]
		for _, c := range d.caches {
			oldCacheState = append(oldCacheState, c.State()...)
		}

		var cacheState []*mlx.Array
		withStream(func() {
			x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
			d.model.Forward(x, d.caches)
			for _, c := range d.caches {
				cacheState = append(cacheState, c.State()...)
			}
		})
		mlx.Eval(cacheState...)

		// Free old cache state
		for _, arr := range oldCacheState {
			if arr != nil {
				arr.Free()
			}
		}

		inputIDs = inputIDs[chunkSize:]
		processed += chunkSize
	}

	// Save old cache state before final step
	oldCacheState = oldCacheState[:0]
	for _, c := range d.caches {
		oldCacheState = append(oldCacheState, c.State()...)
	}

	// Final token + sampling
	withStream(func() {
		x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
		mlx.Eval(x) // Materialize before any other evals
		logits := d.model.Forward(x, d.caches)
		d.token = sample(logits, d.temp, d.vocabSize)
	})
	// Keep cache state (token auto-kept by AsyncEval)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Free old cache state from before final step
	for _, arr := range oldCacheState {
		if arr != nil {
			arr.Free()
		}
	}

	mlx.ClearCache()

	return processed + len(inputIDs)
}

func (d *Decoder) step() int32 {
	prevToken := d.token

	// Save old cache state (reuse preallocated slice)
	d.oldCacheState = d.oldCacheState[:0]
	for _, c := range d.caches {
		d.oldCacheState = append(d.oldCacheState, c.State()...)
	}

	withStream(func() {
		logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
		d.token = sample(logits, d.temp, d.vocabSize)
	})
	// Keep token and new cache state so they survive cleanup
	mlx.Keep(d.token)
	for _, c := range d.caches {
		mlx.Keep(c.State()...)
	}
	mlx.AsyncEval(d.token)

	// Sync on previous token (GPU already working on next step)
	val := prevToken.ItemInt32()

	// Free old token and old cache state
	prevToken.Free()
	for _, arr := range d.oldCacheState {
		arr.Free()
	}
	return val
}

// sample samples from logits using temperature scaling
func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
	// Get last position logits: [1, L, vocab] -> [vocab]
	shape := logits.Shape()
	seqLen := shape[1]
	lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
	lastLogits = mlx.Reshape(lastLogits, vocabSize)

	if temp <= 0 || temp < 0.01 {
		// Greedy decoding
		return mlx.Argmax(lastLogits, -1, false)
	}

	// Apply temperature scaling
	scaled := mlx.DivScalar(lastLogits, temp)
	return mlx.RandomCategorical(scaled, -1, 1)
}

// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
func (s *server) loadLLMModel() error {
	// Load the manifest to get model information
	manifest, err := imagegen.LoadManifest(s.modelName)
	if err != nil {
		return fmt.Errorf("failed to load manifest: %w", err)
	}

	// Detect model architecture from config.json
	configData, err := manifest.ReadConfig("config.json")
	if err != nil {
		return fmt.Errorf("failed to read config.json: %w", err)
	}

	var modelConfig struct {
		Architectures []string `json:"architectures"`
		ModelType     string   `json:"model_type"`
	}
	if err := json.Unmarshal(configData, &modelConfig); err != nil {
		return fmt.Errorf("failed to parse config.json: %w", err)
	}

	arch := ""
	if len(modelConfig.Architectures) > 0 {
		arch = modelConfig.Architectures[0]
	}
	if arch == "" {
		arch = modelConfig.ModelType
	}

	slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)

	// Load the appropriate model based on architecture
	var model TextModel
	archLower := strings.ToLower(arch)

	switch {
	case strings.Contains(archLower, "glm4moelite"):
		m, err := glm4_moe_lite.LoadFromManifest(manifest)
		if err != nil {
			return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
		}
		model = m
		slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())

	default:
		return fmt.Errorf("LLM architecture %q is not yet supported. "+
			"Supported architectures: glm4-moe-lite. "+
			"Please convert your model to GGUF format or use a supported architecture", arch)
	}

	s.llmModel = &llmState{
		model: model,
	}

	return nil
}

// handleLLMCompletion handles LLM text generation requests.
func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
	if s.llmModel == nil {
		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
		return
	}

	// Serialize generation requests
	llmMu.Lock()
	defer llmMu.Unlock()

	if err := s.llmGenerate(w, r, req); err != nil {
		slog.Error("LLM generation failed", "error", err)
		// Don't send error if we've already started streaming
	}
}

// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
	state := s.llmModel

	// Set up streaming response
	w.Header().Set("Content-Type", "application/x-ndjson")
	w.Header().Set("Transfer-Encoding", "chunked")
	flusher, ok := w.(http.Flusher)
	if !ok {
		return errors.New("streaming not supported")
	}

	tok := state.model.Tokenizer()

	// The prompt is already formatted by the server using the model's renderer
	// (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
	prompt := req.Prompt

	// Tokenize the prompt
	inputIDs := tok.Encode(prompt, true)
	slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))

	// Generation parameters
	maxTokens := int(state.model.MaxContextLength())
	if maxTokens <= 0 {
		maxTokens = 4096
	}
	if req.Options != nil && req.Options.NumPredict > 0 {
		maxTokens = req.Options.NumPredict
	}

	temperature := float32(0.7)
	if req.Options != nil && req.Options.Temperature > 0 {
		temperature = float32(req.Options.Temperature)
	}

	// Enable MLX compilation for better performance
	mlx.EnableCompile()

	// Create decoder with fresh caches
	dec := NewDecoder(state.model, temperature)

	prefillStart := time.Now()
	prefillTokens := dec.prefill(inputIDs)
	// Prefill measurement includes time to first token
	firstToken := dec.step()
	prefillDuration := time.Since(prefillStart)
	promptEvalDuration := prefillDuration

	enc := json.NewEncoder(w)
	ctx := r.Context()
	generated := 0
	stopReason := "max_tokens"

	// Handle first token
	generated++
	if tok.IsEOS(firstToken) {
		resp := Response{
			Done:               true,
			StopReason:         fmt.Sprintf("first_token_eos:%d", firstToken),
			PromptEvalCount:    prefillTokens,
			PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
		}
		enc.Encode(resp)
		flusher.Flush()
		return nil
	}

	text := tok.Decode([]int32{firstToken})
	resp := Response{Content: text}
	enc.Encode(resp)
	flusher.Flush()

	genStart := time.Now()

	// Generation loop
	for n := 1; n < maxTokens; n++ {
		// Check for cancellation
		select {
		case <-ctx.Done():
			stopReason = fmt.Sprintf("context_cancelled:%d", generated)
			break
		default:
		}
		if stopReason != "max_tokens" {
			break
		}

		token := dec.step()
		generated++

		if tok.IsEOS(token) {
			stopReason = fmt.Sprintf("eos_token:%d", token)
			break
		}

		text := tok.Decode([]int32{token})

		// Check for stop sequences
		if req.Options != nil && len(req.Options.Stop) > 0 {
			shouldStop := false
			var matchedStop string
			for _, stop := range req.Options.Stop {
				if strings.Contains(text, stop) {
					text = strings.Split(text, stop)[0]
					shouldStop = true
					matchedStop = stop
					break
				}
			}
			if shouldStop {
				if text != "" {
					resp := Response{Content: text}
					enc.Encode(resp)
					flusher.Flush()
				}
				stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
				break
			}
		}

		resp := Response{Content: text}
		enc.Encode(resp)
		flusher.Flush()

		// Periodically clear MLX cache
		if n%256 == 0 {
			mlx.ClearCache()
		}
	}

	// Clean up
	mlx.ClearCache()

	// Send final response with stats
	evalDuration := time.Since(genStart)
	resp = Response{
		Done:               true,
		StopReason:         fmt.Sprintf("%s:generated=%d", stopReason, generated),
		PromptEvalCount:    prefillTokens,
		PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
		EvalCount:          generated,
		EvalDuration:       int(evalDuration.Nanoseconds()),
	}
	enc.Encode(resp)
	flusher.Flush()

	return nil
}
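Editor's note: the snippet below is an illustrative sketch, not part of the commit. It shows how the Decoder above is intended to be driven (one prefill pass over the prompt, then repeated step() calls until EOS or the token budget runs out), assuming it lives alongside llm.go in the mlxrunner package so it can reach the unexported prefill and step methods.

//go:build mlx

package mlxrunner

// generateN drives a Decoder for up to maxTokens tokens and returns the
// decoded text. Sketch only; error handling and streaming are omitted.
func generateN(m TextModel, prompt string, maxTokens int, temp float32) string {
	tok := m.Tokenizer()
	dec := NewDecoder(m, temp)

	// Feed the whole prompt; the decoder queues the first sampled token.
	dec.prefill(tok.Encode(prompt, true))

	var out []int32
	for i := 0; i < maxTokens; i++ {
		// step() returns the previously sampled token and schedules the next one.
		t := dec.step()
		if tok.IsEOS(t) {
			break
		}
		out = append(out, t)
	}
	return tok.Decode(out)
}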
x/mlxrunner/runner.go
@@ -0,0 +1,204 @@
//go:build mlx

// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
package mlxrunner

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/x/imagegen"
	"github.com/ollama/ollama/x/imagegen/mlx"
)

// Execute is the entry point for the unified MLX runner subprocess.
func Execute(args []string) error {
	// Set up logging with appropriate level from environment
	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))

	fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
	modelName := fs.String("model", "", "path to model")
	port := fs.Int("port", 0, "port to listen on")

	if err := fs.Parse(args); err != nil {
		return err
	}

	if *modelName == "" {
		return fmt.Errorf("--model is required")
	}
	if *port == 0 {
		return fmt.Errorf("--port is required")
	}

	// Initialize MLX
	if err := mlx.InitMLX(); err != nil {
		slog.Error("unable to initialize MLX", "error", err)
		return err
	}
	slog.Info("MLX library initialized")

	// Detect model type from capabilities
	mode := detectModelMode(*modelName)
	slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)

	// Create and start server
	server, err := newServer(*modelName, *port, mode)
	if err != nil {
		return fmt.Errorf("failed to create server: %w", err)
	}

	// Set up HTTP handlers
	mux := http.NewServeMux()
	mux.HandleFunc("/health", server.healthHandler)
	mux.HandleFunc("/completion", server.completionHandler)

	// LLM-specific endpoints
	if mode == ModeLLM {
		mux.HandleFunc("/tokenize", server.tokenizeHandler)
		mux.HandleFunc("/embedding", server.embeddingHandler)
	}

	httpServer := &http.Server{
		Addr:    fmt.Sprintf("127.0.0.1:%d", *port),
		Handler: mux,
	}

	// Handle shutdown
	done := make(chan struct{})
	go func() {
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
		<-sigCh
		slog.Info("shutting down mlx runner")
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		httpServer.Shutdown(ctx)
		close(done)
	}()

	slog.Info("mlx runner listening", "addr", httpServer.Addr)
	if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}

	<-done
	return nil
}

// detectModelMode determines whether a model is an LLM or image generation model.
func detectModelMode(modelName string) ModelMode {
	// Check for image generation model by looking at model_index.json
	modelType := imagegen.DetectModelType(modelName)
	if modelType != "" {
		// Known image generation model types
		switch modelType {
		case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
			return ModeImageGen
		}
	}

	// Default to LLM mode for safetensors models without known image gen types
	return ModeLLM
}

// server holds the model and handles HTTP requests.
type server struct {
	mode      ModelMode
	modelName string
	port      int

	// Image generation model (when mode == ModeImageGen)
	imageModel ImageModel

	// LLM model (when mode == ModeLLM)
	llmModel *llmState
}

// newServer creates a new server instance and loads the appropriate model.
func newServer(modelName string, port int, mode ModelMode) (*server, error) {
	s := &server{
		mode:      mode,
		modelName: modelName,
		port:      port,
	}

	switch mode {
	case ModeImageGen:
		if err := s.loadImageModel(); err != nil {
			return nil, fmt.Errorf("failed to load image model: %w", err)
		}
	case ModeLLM:
		if err := s.loadLLMModel(); err != nil {
			return nil, fmt.Errorf("failed to load LLM model: %w", err)
		}
	}

	return s, nil
}

func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
	resp := HealthResponse{Status: "ok"}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req Request
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	switch s.mode {
	case ModeImageGen:
		s.handleImageCompletion(w, r, req)
	case ModeLLM:
		s.handleLLMCompletion(w, r, req)
	}
}

func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
	if s.llmModel == nil {
		http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
		return
	}

	var req struct {
		Content string `json:"content"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	tok := s.llmModel.model.Tokenizer()
	tokens := tok.Encode(req.Content, false)

	// Convert int32 to int for JSON response
	intTokens := make([]int, len(tokens))
	for i, t := range tokens {
		intTokens[i] = int(t)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
}

func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
	http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
}
@@ -1,10 +1,10 @@
//go:build !mlx

package runner
package mlxrunner

import "errors"

// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
	return errors.New("image generation not available: build with mlx tag")
	return errors.New("MLX runner not available: build with mlx tag")
}
@@ -1,4 +1,4 @@
package imagegen
package mlxrunner

import (
	"bufio"
@@ -23,19 +23,19 @@ import (

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/x/imagegen"
)

// Server wraps an image generation subprocess to implement llm.LlamaServer.
// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
// like any other model. The plan is to eventually bring this into the llm/ package
// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
// separate allows for independent iteration on image generation support.
// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
	mu        sync.Mutex
	cmd       *exec.Cmd
	port      int
	modelName string
	mode      ModelMode
	vramSize  uint64
	done      chan error
	client    *http.Client
@@ -43,10 +43,10 @@ type Server struct {
	lastErrLock sync.Mutex
}

// NewServer spawns a new image generation subprocess and waits until it's ready.
func NewServer(modelName string) (*Server, error) {
// NewServer spawns a new MLX runner subprocess and waits until it's ready.
func NewServer(modelName string, mode ModelMode) (*Server, error) {
	// Validate platform support before attempting to start
	if err := CheckPlatformSupport(); err != nil {
	if err := imagegen.CheckPlatformSupport(); err != nil {
		return nil, err
	}

@@ -71,8 +71,8 @@ func NewServer(modelName string) (*Server, error) {
		exe = eval
	}

	// Spawn subprocess: ollama runner --image-engine --model <path> --port <port>
	cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
	// Spawn subprocess: ollama runner --mlx-engine --model <path> --port <port>
	cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
	cmd.Env = os.Environ()

	// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -105,17 +105,21 @@ func NewServer(modelName string) (*Server, error) {
		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
	}

	// Get total weight size from manifest
	var weightSize uint64
	if manifest, err := LoadManifest(modelName); err == nil {
		weightSize = uint64(manifest.TotalTensorSize())
	// Estimate VRAM based on tensor size from manifest
	var vramSize uint64
	if manifest, err := imagegen.LoadManifest(modelName); err == nil {
		vramSize = uint64(manifest.TotalTensorSize())
	} else {
		// Fallback: default to 8GB if manifest can't be loaded
		vramSize = 8 * 1024 * 1024 * 1024
	}

	s := &Server{
		cmd:       cmd,
		port:      port,
		modelName: modelName,
		vramSize:  weightSize,
		mode:      mode,
		vramSize:  vramSize,
		done:      make(chan error, 1),
		client:    &http.Client{Timeout: 10 * time.Minute},
	}
@@ -126,23 +130,23 @@ func NewServer(modelName string) (*Server, error) {
	go func() {
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			slog.Info("image-runner", "msg", scanner.Text())
			slog.Info("mlx-runner", "msg", scanner.Text())
		}
	}()
	go func() {
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			slog.Warn("image-runner", "msg", line)
			slog.Warn("mlx-runner", "msg", line)
			s.lastErrLock.Lock()
			s.lastErr = line
			s.lastErrLock.Unlock()
		}
	}()

	slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
	slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start image runner: %w", err)
		return nil, fmt.Errorf("failed to start mlx runner: %w", err)
	}

	// Reap subprocess when it exits
@@ -165,6 +169,7 @@ func (s *Server) ModelPath() string {
	return s.modelName
}

// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	return nil, nil
}
@@ -200,18 +205,18 @@ func (s *Server) waitUntilRunning() error {
			// Include recent stderr lines for better error context
			errMsg := s.getLastErr()
			if errMsg != "" {
				return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
				return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
			}
			return fmt.Errorf("image runner exited unexpectedly: %w", err)
			return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
		case <-timeout:
			errMsg := s.getLastErr()
			if errMsg != "" {
				return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
				return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
			}
			return errors.New("timeout waiting for image runner to start")
			return errors.New("timeout waiting for mlx runner to start")
		case <-ticker.C:
			if err := s.Ping(ctx); err == nil {
				slog.Info("image runner is ready", "port", s.port)
				slog.Info("mlx runner is ready", "port", s.port)
				return nil
			}
		}
@@ -225,8 +230,12 @@ func (s *Server) getLastErr() string {
	return s.lastErr
}

func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
// WaitUntilRunning satisfies the LlamaServer interface.
func (s *Server) WaitUntilRunning(ctx context.Context) error {
	return nil
}

// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	seed := req.Seed
	if seed == 0 {
@@ -240,22 +249,26 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
	}

	// Build request for subprocess
	creq := struct {
		Prompt string   `json:"prompt"`
		Width  int32    `json:"width,omitempty"`
		Height int32    `json:"height,omitempty"`
		Steps  int32    `json:"steps,omitempty"`
		Seed   int64    `json:"seed,omitempty"`
		Images [][]byte `json:"images,omitempty"`
	}{
	creq := Request{
		Prompt: req.Prompt,
		Width:  req.Width,
		Height: req.Height,
		Steps:  req.Steps,
		Steps:  int(req.Steps),
		Seed:   seed,
		Images: images,
	}

	// Pass LLM options if present
	if req.Options != nil {
		creq.Options = &RequestOptions{
			NumPredict:  req.Options.NumPredict,
			Temperature: float64(req.Options.Temperature),
			TopP:        float64(req.Options.TopP),
			TopK:        req.Options.TopK,
			Stop:        req.Options.Stop,
		}
	}

	body, err := json.Marshal(creq)
	if err != nil {
		return err
@@ -282,25 +295,40 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
	for scanner.Scan() {
		// Parse subprocess response (has singular "image" field)
		// Parse subprocess response
		var raw struct {
			Image   string `json:"image,omitempty"`
			Content string `json:"content,omitempty"`
			Done    bool   `json:"done"`
			Step    int    `json:"step,omitempty"`
			Total   int    `json:"total,omitempty"`
			Image              string `json:"image,omitempty"`
			Content            string `json:"content,omitempty"`
			Done               bool   `json:"done"`
			Step               int    `json:"step,omitempty"`
			Total              int    `json:"total,omitempty"`
			StopReason         string `json:"stop_reason,omitempty"`
			PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`
			PromptEvalDuration int    `json:"prompt_eval_duration,omitempty"`
			EvalCount          int    `json:"eval_count,omitempty"`
			EvalDuration       int    `json:"eval_duration,omitempty"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
			slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
			continue
		}

		// Log stop reason when generation completes
		if raw.Done && raw.StopReason != "" {
			slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
		}

		// Convert to llm.CompletionResponse
		cresp := llm.CompletionResponse{
			Content:    raw.Content,
			Done:       raw.Done,
			Step:       raw.Step,
			TotalSteps: raw.Total,
			Image:      raw.Image,
			Content:            raw.Content,
			Done:               raw.Done,
			Step:               raw.Step,
			TotalSteps:         raw.Total,
			Image:              raw.Image,
			PromptEvalCount:    raw.PromptEvalCount,
			PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
			EvalCount:          raw.EvalCount,
			EvalDuration:       time.Duration(raw.EvalDuration),
		}

		fn(cresp)
@@ -309,7 +337,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
		}
	}

	return scanner.Err()
	// Scanner exited without receiving Done - connection was likely closed
	scanErr := scanner.Err()
	if scanErr != nil {
		slog.Error("mlx scanner error", "error", scanErr)
	} else {
		slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
	}

	// Check if subprocess is still alive
	if s.HasExited() {
		slog.Error("mlx subprocess has exited unexpectedly")
	}

	return scanErr
}

// Close terminates the subprocess.
@@ -318,7 +359,7 @@ func (s *Server) Close() error {
	defer s.mu.Unlock()

	if s.cmd != nil && s.cmd.Process != nil {
		slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
		slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
		s.cmd.Process.Signal(os.Interrupt)

		// Wait briefly for graceful shutdown
@@ -347,18 +388,51 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
	return s.vramSize
}

// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	return nil, 0, errors.New("not supported")
	return nil, 0, errors.New("embeddings not supported for MLX models")
}

// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
	return nil, errors.New("not supported")
	body, err := json.Marshal(map[string]string{"content": content})
	if err != nil {
		return nil, err
	}

	url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
	}

	var result struct {
		Tokens []int `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}

	return result.Tokens, nil
}

// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
	return "", errors.New("not supported")
	return "", errors.New("detokenization not supported for MLX models")
}

// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
	s.mu.Lock()
	defer s.mu.Unlock()
@@ -368,9 +442,17 @@ func (s *Server) Pid() int {
	return -1
}

func (s *Server) GetPort() int { return s.port }
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
// GetPort returns the port the subprocess is listening on.
func (s *Server) GetPort() int {
	return s.port
}

// GetDeviceInfos returns device information.
func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	return nil
}

// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
	select {
	case <-s.done:
x/mlxrunner/types.go
@@ -0,0 +1,81 @@
// Package mlxrunner provides a unified MLX runner for both LLM and image generation models.
//
// This package handles safetensors models created with `ollama create --experimental`,
// supporting both text generation (LLM) and image generation (diffusion) models
// through a single unified interface.
package mlxrunner

// Request is the request format for completion requests.
type Request struct {
	Prompt string `json:"prompt"`

	// LLM-specific fields
	Options *RequestOptions `json:"options,omitempty"`

	// Image generation fields
	Width  int32    `json:"width,omitempty"`
	Height int32    `json:"height,omitempty"`
	Steps  int      `json:"steps,omitempty"`
	Seed   int64    `json:"seed,omitempty"`
	Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
}

// RequestOptions contains LLM-specific generation options.
type RequestOptions struct {
	NumPredict  int      `json:"num_predict,omitempty"`
	Temperature float64  `json:"temperature,omitempty"`
	TopP        float64  `json:"top_p,omitempty"`
	TopK        int      `json:"top_k,omitempty"`
	Stop        []string `json:"stop,omitempty"`
}

// Response is streamed back for each progress update.
type Response struct {
	// Text generation response
	Content string `json:"content,omitempty"`

	// Image generation response
	Image string `json:"image,omitempty"` // Base64-encoded PNG

	// Common fields
	Done       bool   `json:"done"`
	DoneReason int    `json:"done_reason,omitempty"`
	StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped

	// Progress fields
	Step  int `json:"step,omitempty"`
	Total int `json:"total,omitempty"`

	// Statistics
	PromptEvalCount    int `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
	EvalCount          int `json:"eval_count,omitempty"`
	EvalDuration       int `json:"eval_duration,omitempty"`
}

// HealthResponse is returned by the health endpoint.
type HealthResponse struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress,omitempty"`
}

// ModelMode represents the type of model being run.
type ModelMode int

const (
	// ModeLLM indicates a text generation model.
	ModeLLM ModelMode = iota
	// ModeImageGen indicates an image generation model.
	ModeImageGen
)

func (m ModelMode) String() string {
	switch m {
	case ModeLLM:
		return "llm"
	case ModeImageGen:
		return "imagegen"
	default:
		return "unknown"
	}
}
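Editor's note: the client sketch below is not part of the commit. It illustrates how the Request and Response types above travel over the runner's /completion endpoint (newline-delimited JSON: content chunks, then a final done=true record with eval statistics). The completionClient helper, its option values, and the choice of a standalone client are assumptions made for the example; the endpoint path, loopback address, and field names come from the code in this change.

package mlxrunner

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// completionClient posts a text-generation Request to a running mlx runner
// and prints the streamed content. port is whatever --port the subprocess
// was started with.
func completionClient(port int, prompt string) error {
	body, err := json.Marshal(Request{
		Prompt:  prompt,
		Options: &RequestOptions{NumPredict: 64, Temperature: 0.7},
	})
	if err != nil {
		return err
	}

	resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/completion", port), "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Each line is one JSON Response.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var r Response
		if err := json.Unmarshal(scanner.Bytes(), &r); err != nil {
			continue
		}
		fmt.Print(r.Content)
		if r.Done {
			break
		}
	}
	return scanner.Err()
}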
@@ -87,7 +87,7 @@ func New(c fs.Config) (model.Model, error) {
	// m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))

	// TODO need to implement sliding window...
	m.Cache = kvcache.NewMLXCausalCache()
	m.Cache = kvcache.NewCausalCache()

	return &m, nil
}

@@ -199,7 +199,7 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}

// GetSafetensorsDtype returns the quantization type for a safetensors model.
// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
	mf, err := manifest.ParseNamedManifest(name)
@@ -207,16 +207,38 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
		return "", fmt.Errorf("failed to load manifest: %w", err)
	}

	// Check if model is quantized by looking for _scale tensors
	// First try to read quantization from model_index.json
	var modelIndex struct {
		Quantization string `json:"quantization"`
	}
	if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
		return modelIndex.Quantization, nil
	}

	// Fallback: detect from tensor names
	hasScales := false
	hasQBias := false
	for _, layer := range mf.Layers {
		if layer.MediaType == manifest.MediaTypeImageTensor {
			if strings.HasSuffix(layer.Name, "_scale") {
				// Model is quantized - return FP8 (affine quantization)
				return "FP8", nil
				hasScales = true
			}
			if strings.HasSuffix(layer.Name, "_qbias") {
				hasQBias = true
			}
		}
	}

	if hasScales {
		if hasQBias {
			// Affine mode (has scale + qbias) - could be FP4 or FP8
			// Default to FP4 as it's more common
			return "FP4", nil
		}
		// No qbias = NVFP4
		return "NVFP4", nil
	}

	// Not quantized - return torch_dtype from config.json
	var cfg struct {
		TorchDtype string `json:"torch_dtype"`

x/server/thinking.go
@@ -0,0 +1,51 @@
package server

import (
	"strings"

	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/types/model"
)

// IsSafetensorsThinkingModel checks if a safetensors model supports thinking
// based on its architecture from config.json.
func IsSafetensorsThinkingModel(name model.Name) bool {
	mf, err := manifest.ParseNamedManifest(name)
	if err != nil {
		return false
	}

	var config struct {
		Architectures []string `json:"architectures"`
		ModelType     string   `json:"model_type"`
	}
	if err := mf.ReadConfigJSON("config.json", &config); err != nil {
		return false
	}

	// Determine architecture
	arch := config.ModelType
	if arch == "" && len(config.Architectures) > 0 {
		arch = config.Architectures[0]
	}
	if arch == "" {
		return false
	}

	archLower := strings.ToLower(arch)

	// List of architectures that support thinking
	thinkingArchitectures := []string{
		"glm4moe",  // GLM-4 MoE models
		"deepseek", // DeepSeek models
		"qwen3",    // Qwen3 models
	}

	for _, thinkArch := range thinkingArchitectures {
		if strings.Contains(archLower, thinkArch) {
			return true
		}
	}

	return false
}