mirror of
https://github.com/ollama/ollama.git
synced 2026-06-03 22:13:30 -04:00
* mlx: Support NVIDIA TensorRT Model Optimizer import * x/create: support FP8 safetensors import Decode HF F8_E4M3 safetensors with block scale companions into MLX-importable tensor blobs, including compressed-tensors weight_scale metadata, packed NVFP4 layouts, and mixed-precision tensor headers. Use that source-precision metadata during create quantization: default FP8-sourced imports to mxfp8, allow source FP8 to target MLX low-bit formats, preserve source-quantized NVFP4 layouts, selectively keep or promote tensors based on their source precision, and detect quantized dtype from mixed-precision safetensors manifests. * review comments
25 lines
584 B
Go
25 lines
584 B
Go
package client
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func TestDecodeSourceFP8TensorAcceptsWeightScale(t *testing.T) {
|
|
if err := mlx.CheckInit(); err != nil {
|
|
t.Skipf("MLX unavailable: %v", err)
|
|
}
|
|
|
|
weight := mlx.FromValues([]uint8{0, 1, 2, 3}, 2, 2)
|
|
scale := mlx.FromValues([]float32{1}, 1, 1).AsType(mlx.DTypeBFloat16)
|
|
got, err := decodeSourceFP8Tensor(weight, scale)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
mlx.Eval(got)
|
|
if dims := got.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 2 {
|
|
t.Fatalf("decoded dims = %v, want [2 2]", dims)
|
|
}
|
|
}
|