mirror of
https://github.com/ollama/ollama.git
synced 2026-01-20 21:40:54 -05:00
Compare commits
3 Commits
parth/decr
...
mxyng/quan
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8b1f9e1a1 | ||
|
|
362bf0901f | ||
|
|
2a03498bbb |
@@ -5,6 +5,7 @@ package ggml
|
|||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/ggml-cpu
|
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/ggml-cpu
|
||||||
// #cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
|
// #cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
|
||||||
// #cgo windows LDFLAGS: -lmsvcrt -static -static-libgcc -static-libstdc++
|
// #cgo windows LDFLAGS: -lmsvcrt -static -static-libgcc -static-libstdc++
|
||||||
|
// #cgo windows linux CPPFLAGS: -DGGML_FP16_TO_FP32=ggml_compute_fp16_to_fp32
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "ggml-backend.h"
|
// #include "ggml-backend.h"
|
||||||
// extern void sink(int level, char *text, void *user_data);
|
// extern void sink(int level, char *text, void *user_data);
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ package ggml
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"iter"
|
||||||
|
"slices"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||||
@@ -50,34 +52,30 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
|
|||||||
return f32s
|
return f32s
|
||||||
}
|
}
|
||||||
|
|
||||||
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) []byte {
|
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) iter.Seq[[]byte] {
|
||||||
buf := make([]byte, len(f32s)*4) // upper bound on size
|
return func(yield func([]byte) bool) {
|
||||||
nPerRow := C.int64_t(shape[0])
|
C.ggml_quantize_init(uint32(newType))
|
||||||
nrows := C.int64_t(1)
|
defer C.ggml_quantize_free()
|
||||||
if len(shape) > 1 {
|
|
||||||
nrows = C.int64_t(shape[1])
|
|
||||||
}
|
|
||||||
shape2 := C.int64_t(1)
|
|
||||||
if len(shape) > 2 {
|
|
||||||
shape2 = C.int64_t(shape[2])
|
|
||||||
}
|
|
||||||
nelements_matrix := nPerRow * nrows
|
|
||||||
newSize := C.size_t(0)
|
|
||||||
for i03 := C.int64_t(0); i03 < shape2; i03++ {
|
|
||||||
f32s_03 := i03 * nelements_matrix
|
|
||||||
buf_03 := C.int64_t(C.ggml_row_size(uint32(newType), nPerRow)) * i03 * nrows
|
|
||||||
newSize += C.ggml_quantize_chunk(
|
|
||||||
uint32(newType),
|
|
||||||
(*C.float)(&f32s[f32s_03]),
|
|
||||||
unsafe.Pointer((uintptr)(unsafe.Pointer(&buf[0]))+uintptr(buf_03)),
|
|
||||||
0,
|
|
||||||
nrows,
|
|
||||||
nPerRow,
|
|
||||||
nil)
|
|
||||||
}
|
|
||||||
return buf[:newSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
func QuantizationVersion() uint32 {
|
dims := slices.Repeat([]C.int64_t{1}, 4)
|
||||||
return uint32(C.GGML_QNT_VERSION)
|
for i, s := range shape {
|
||||||
|
dims[i] = C.int64_t(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
bts := make([]byte, C.ggml_row_size(uint32(newType), dims[0])*C.size_t(dims[1]))
|
||||||
|
for chunk := range dims[2] {
|
||||||
|
offset := chunk * dims[0] * dims[1]
|
||||||
|
|
||||||
|
n := C.ggml_quantize_chunk(
|
||||||
|
uint32(newType),
|
||||||
|
(*C.float)(&f32s[0]),
|
||||||
|
unsafe.Pointer(&bts[0]),
|
||||||
|
offset, dims[1], dims[0], nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
if !yield(bts[:n]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,10 +40,19 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
|||||||
} else {
|
} else {
|
||||||
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
||||||
}
|
}
|
||||||
data = ggml.Quantize(newType, f32s, q.from.Shape)
|
|
||||||
n, err := w.Write(data)
|
var n int64
|
||||||
q.progressFn(q.from.Size())
|
for bts := range ggml.Quantize(newType, f32s, q.from.Shape) {
|
||||||
return int64(n), err
|
nn, err := w.Write(bts)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q.progressFn(uint64(nn))
|
||||||
|
n += int64(nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type quantizeState struct {
|
type quantizeState struct {
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/ml/backend/ggml"
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
)
|
)
|
||||||
@@ -649,3 +651,55 @@ var (
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestQuantizer(t *testing.T) {
|
||||||
|
from := fsggml.Tensor{
|
||||||
|
Name: "fp32",
|
||||||
|
Shape: []uint64{256},
|
||||||
|
Kind: uint32(fsggml.TensorTypeF32),
|
||||||
|
}
|
||||||
|
|
||||||
|
temp, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, 256)
|
||||||
|
for i := range f32s {
|
||||||
|
f32s[i] = float32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(temp, binary.LittleEndian, f32s); err != nil {
|
||||||
|
t.Fatalf("failed to write to temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for type_, want := range quantBytes {
|
||||||
|
t.Run(type_.String(), func(t *testing.T) {
|
||||||
|
f, err := os.Open(temp.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open temp file: %v", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
q := quantizer{
|
||||||
|
File: f,
|
||||||
|
from: &from,
|
||||||
|
to: &fsggml.Tensor{
|
||||||
|
Name: type_.String(),
|
||||||
|
Shape: from.Shape,
|
||||||
|
Kind: uint32(type_),
|
||||||
|
},
|
||||||
|
progressFn: func(uint64) {},
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := q.WriteTo(&b); err != nil {
|
||||||
|
t.Fatalf("WriteTo failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(b.Bytes(), want); diff != "" {
|
||||||
|
t.Errorf("quantized data mismatch for %s (-got +want):\n%s", type_, diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user