Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Yang
c8b1f9e1a1 fix quantization 2025-07-23 13:14:50 -07:00
Michael Yang
362bf0901f cleanup quantization 2025-07-23 13:14:50 -07:00
Michael Yang
2a03498bbb iter quant 2025-07-23 13:14:50 -07:00
4 changed files with 95 additions and 33 deletions

View File

@@ -5,6 +5,7 @@ package ggml
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/ggml-cpu
// #cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
// #cgo windows LDFLAGS: -lmsvcrt -static -static-libgcc -static-libstdc++
// #cgo windows linux CPPFLAGS: -DGGML_FP16_TO_FP32=ggml_compute_fp16_to_fp32
// #include <stdlib.h>
// #include "ggml-backend.h"
// extern void sink(int level, char *text, void *user_data);

View File

@@ -10,6 +10,8 @@ package ggml
import "C"
import (
"iter"
"slices"
"unsafe"
fsggml "github.com/ollama/ollama/fs/ggml"
@@ -50,34 +52,30 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
return f32s
}
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) []byte {
buf := make([]byte, len(f32s)*4) // upper bound on size
nPerRow := C.int64_t(shape[0])
nrows := C.int64_t(1)
if len(shape) > 1 {
nrows = C.int64_t(shape[1])
}
shape2 := C.int64_t(1)
if len(shape) > 2 {
shape2 = C.int64_t(shape[2])
}
nelements_matrix := nPerRow * nrows
newSize := C.size_t(0)
for i03 := C.int64_t(0); i03 < shape2; i03++ {
f32s_03 := i03 * nelements_matrix
buf_03 := C.int64_t(C.ggml_row_size(uint32(newType), nPerRow)) * i03 * nrows
newSize += C.ggml_quantize_chunk(
uint32(newType),
(*C.float)(&f32s[f32s_03]),
unsafe.Pointer((uintptr)(unsafe.Pointer(&buf[0]))+uintptr(buf_03)),
0,
nrows,
nPerRow,
nil)
}
return buf[:newSize]
}
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
C.ggml_quantize_init(uint32(newType))
defer C.ggml_quantize_free()
func QuantizationVersion() uint32 {
return uint32(C.GGML_QNT_VERSION)
dims := slices.Repeat([]C.int64_t{1}, 4)
for i, s := range shape {
dims[i] = C.int64_t(s)
}
bts := make([]byte, C.ggml_row_size(uint32(newType), dims[0])*C.size_t(dims[1]))
for chunk := range dims[2] {
offset := chunk * dims[0] * dims[1]
n := C.ggml_quantize_chunk(
uint32(newType),
(*C.float)(&f32s[0]),
unsafe.Pointer(&bts[0]),
offset, dims[1], dims[0], nil,
)
if !yield(bts[:n]) {
return
}
}
}
}

View File

@@ -40,10 +40,19 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
}
data = ggml.Quantize(newType, f32s, q.from.Shape)
n, err := w.Write(data)
q.progressFn(q.from.Size())
return int64(n), err
var n int64
for bts := range ggml.Quantize(newType, f32s, q.from.Shape) {
nn, err := w.Write(bts)
if err != nil {
return 0, err
}
q.progressFn(uint64(nn))
n += int64(nn)
}
return n, err
}
type quantizeState struct {

View File

@@ -2,12 +2,14 @@ package server
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)
@@ -649,3 +651,55 @@ var (
},
}
)
func TestQuantizer(t *testing.T) {
from := fsggml.Tensor{
Name: "fp32",
Shape: []uint64{256},
Kind: uint32(fsggml.TensorTypeF32),
}
temp, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
f32s := make([]float32, 256)
for i := range f32s {
f32s[i] = float32(i)
}
if err := binary.Write(temp, binary.LittleEndian, f32s); err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}
for type_, want := range quantBytes {
t.Run(type_.String(), func(t *testing.T) {
f, err := os.Open(temp.Name())
if err != nil {
t.Fatalf("failed to open temp file: %v", err)
}
defer f.Close()
q := quantizer{
File: f,
from: &from,
to: &fsggml.Tensor{
Name: type_.String(),
Shape: from.Shape,
Kind: uint32(type_),
},
progressFn: func(uint64) {},
}
var b bytes.Buffer
if _, err := q.WriteTo(&b); err != nil {
t.Fatalf("WriteTo failed: %v", err)
}
if diff := cmp.Diff(b.Bytes(), want); diff != "" {
t.Errorf("quantized data mismatch for %s (-got +want):\n%s", type_, diff)
}
})
}
}