mirror of
https://github.com/ollama/ollama.git
synced 2026-01-21 05:48:35 -05:00
Compare commits
6 Commits
parth/decr
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cf62838ce | ||
|
|
22aed78048 | ||
|
|
d3cbbbfd85 | ||
|
|
e8d1933b99 | ||
|
|
735e80787b | ||
|
|
8e3998b9dd |
350
fs/gguf/gguf.go
Normal file
350
fs/gguf/gguf.go
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"cmp"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"iter"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
typeUint8 uint32 = iota
|
||||||
|
typeInt8
|
||||||
|
typeUint16
|
||||||
|
typeInt16
|
||||||
|
typeUint32
|
||||||
|
typeInt32
|
||||||
|
typeFloat32
|
||||||
|
typeBool
|
||||||
|
typeString
|
||||||
|
typeArray
|
||||||
|
typeUint64
|
||||||
|
typeInt64
|
||||||
|
typeFloat64
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrUnsupported = errors.New("unsupported")
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
Magic [4]byte
|
||||||
|
Version uint32
|
||||||
|
|
||||||
|
keyValues *lazy[KeyValue]
|
||||||
|
tensors *lazy[TensorInfo]
|
||||||
|
offset int64
|
||||||
|
|
||||||
|
file *os.File
|
||||||
|
reader *readSeeker
|
||||||
|
bts []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func Open(path string) (f *File, err error) {
|
||||||
|
f = &File{bts: make([]byte, 4096)}
|
||||||
|
f.file, err = os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.reader = newReadSeeker(f.file, 32<<10)
|
||||||
|
|
||||||
|
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||||
|
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.Version != 3 {
|
||||||
|
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.tensors, err = newLazy(f, f.readTensor)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f.tensors.doneFunc = func() error {
|
||||||
|
offset, err := f.reader.Seek(0, io.SeekCurrent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||||
|
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) readTensor() (TensorInfo, error) {
|
||||||
|
name, err := readString(f)
|
||||||
|
if err != nil {
|
||||||
|
return TensorInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dims, err := read[uint32](f)
|
||||||
|
if err != nil {
|
||||||
|
return TensorInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
shape := make([]uint64, dims)
|
||||||
|
for i := range dims {
|
||||||
|
shape[i], err = read[uint64](f)
|
||||||
|
if err != nil {
|
||||||
|
return TensorInfo{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type_, err := read[uint32](f)
|
||||||
|
if err != nil {
|
||||||
|
return TensorInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset, err := read[uint64](f)
|
||||||
|
if err != nil {
|
||||||
|
return TensorInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return TensorInfo{
|
||||||
|
Name: name,
|
||||||
|
Offset: offset,
|
||||||
|
Shape: shape,
|
||||||
|
Type: TensorType(type_),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) readKeyValue() (KeyValue, error) {
|
||||||
|
key, err := readString(f)
|
||||||
|
if err != nil {
|
||||||
|
return KeyValue{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := read[uint32](f)
|
||||||
|
if err != nil {
|
||||||
|
return KeyValue{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := func() (any, error) {
|
||||||
|
switch t {
|
||||||
|
case typeUint8:
|
||||||
|
return read[uint8](f)
|
||||||
|
case typeInt8:
|
||||||
|
return read[int8](f)
|
||||||
|
case typeUint16:
|
||||||
|
return read[uint16](f)
|
||||||
|
case typeInt16:
|
||||||
|
return read[int16](f)
|
||||||
|
case typeUint32:
|
||||||
|
return read[uint32](f)
|
||||||
|
case typeInt32:
|
||||||
|
return read[int32](f)
|
||||||
|
case typeUint64:
|
||||||
|
return read[uint64](f)
|
||||||
|
case typeInt64:
|
||||||
|
return read[int64](f)
|
||||||
|
case typeFloat32:
|
||||||
|
return read[float32](f)
|
||||||
|
case typeFloat64:
|
||||||
|
return read[float64](f)
|
||||||
|
case typeBool:
|
||||||
|
return read[bool](f)
|
||||||
|
case typeString:
|
||||||
|
return readString(f)
|
||||||
|
case typeArray:
|
||||||
|
return readArray(f)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return KeyValue{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return KeyValue{
|
||||||
|
Key: key,
|
||||||
|
Value: Value{value},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func read[T any](f *File) (t T, err error) {
|
||||||
|
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readString(f *File) (string, error) {
|
||||||
|
n, err := read[uint64](f)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(n) > len(f.bts) {
|
||||||
|
f.bts = make([]byte, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
bts := f.bts[:n]
|
||||||
|
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer clear(bts)
|
||||||
|
|
||||||
|
return string(bts), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readArray(f *File) (any, error) {
|
||||||
|
t, err := read[uint32](f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := read[uint64](f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case typeUint8:
|
||||||
|
return readArrayData[uint8](f, n)
|
||||||
|
case typeInt8:
|
||||||
|
return readArrayData[int8](f, n)
|
||||||
|
case typeUint16:
|
||||||
|
return readArrayData[uint16](f, n)
|
||||||
|
case typeInt16:
|
||||||
|
return readArrayData[int16](f, n)
|
||||||
|
case typeUint32:
|
||||||
|
return readArrayData[uint32](f, n)
|
||||||
|
case typeInt32:
|
||||||
|
return readArrayData[int32](f, n)
|
||||||
|
case typeUint64:
|
||||||
|
return readArrayData[uint64](f, n)
|
||||||
|
case typeInt64:
|
||||||
|
return readArrayData[int64](f, n)
|
||||||
|
case typeFloat32:
|
||||||
|
return readArrayData[float32](f, n)
|
||||||
|
case typeFloat64:
|
||||||
|
return readArrayData[float64](f, n)
|
||||||
|
case typeBool:
|
||||||
|
return readArrayData[bool](f, n)
|
||||||
|
case typeString:
|
||||||
|
return readArrayString(f, n)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||||
|
s = make([]T, n)
|
||||||
|
for i := range n {
|
||||||
|
e, err := read[T](f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s[i] = e
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||||
|
s = make([]string, n)
|
||||||
|
for i := range n {
|
||||||
|
e, err := readString(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s[i] = e
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) Close() error {
|
||||||
|
f.keyValues.stop()
|
||||||
|
f.tensors.stop()
|
||||||
|
return f.file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) KeyValue(key string) KeyValue {
|
||||||
|
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||||
|
key = f.KeyValue("general.architecture").String() + "." + key
|
||||||
|
}
|
||||||
|
|
||||||
|
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||||
|
return kv.Key == key
|
||||||
|
}); index >= 0 {
|
||||||
|
return f.keyValues.values[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||||
|
if keyValue.Key == key {
|
||||||
|
return keyValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return KeyValue{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) NumKeyValues() int {
|
||||||
|
return int(f.keyValues.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||||
|
return f.keyValues.All()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) TensorInfo(name string) TensorInfo {
|
||||||
|
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||||
|
return t.Name == name
|
||||||
|
}); index >= 0 {
|
||||||
|
return f.tensors.values[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
// fast-forward through key values if we haven't already
|
||||||
|
_ = f.keyValues.rest()
|
||||||
|
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||||
|
if tensor.Name == name {
|
||||||
|
return tensor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TensorInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) NumTensors() int {
|
||||||
|
return int(f.tensors.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||||
|
// fast forward through key values if we haven't already
|
||||||
|
f.keyValues.rest()
|
||||||
|
return f.tensors.All()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||||
|
t := f.TensorInfo(name)
|
||||||
|
if t.NumBytes() == 0 {
|
||||||
|
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fast forward through tensor info if we haven't already
|
||||||
|
_ = f.tensors.rest()
|
||||||
|
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||||
|
}
|
||||||
320
fs/gguf/gguf_test.go
Normal file
320
fs/gguf/gguf_test.go
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRead(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
tempFile := filepath.Join(tempDir, "test.gguf")
|
||||||
|
|
||||||
|
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"general.alignment": int64(32),
|
||||||
|
}, []testTensorInfo{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := Open(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Test
|
||||||
|
if got := f.NumKeyValues(); got != 2 {
|
||||||
|
t.Errorf("NumKeyValues() = %d, want %d", got, 2)
|
||||||
|
}
|
||||||
|
if got := f.NumTensors(); got != 2 {
|
||||||
|
t.Errorf("NumTensors() = %d, want %d", got, 2)
|
||||||
|
}
|
||||||
|
archKV := f.KeyValue("general.architecture")
|
||||||
|
if archKV.Key == "" {
|
||||||
|
t.Error("KeyValue(\"general.architecture\") not found")
|
||||||
|
}
|
||||||
|
if got := archKV.String(); got != "llama" {
|
||||||
|
t.Errorf("KeyValue(\"general.architecture\").String() = %q, want %q", got, "llama")
|
||||||
|
}
|
||||||
|
alignKV := f.KeyValue("general.alignment")
|
||||||
|
if alignKV.Key == "" {
|
||||||
|
t.Error("KeyValue(\"general.alignment\") not found")
|
||||||
|
}
|
||||||
|
if got := alignKV.Int(); got != 32 {
|
||||||
|
t.Errorf("KeyValue(\"general.alignment\").Int() = %d, want %d", got, 32)
|
||||||
|
}
|
||||||
|
expectedTensorNames := []string{"token_embd.weight", "output.weight"}
|
||||||
|
var gotTensorNames []string
|
||||||
|
for _, tensor := range f.TensorInfos() {
|
||||||
|
gotTensorNames = append(gotTensorNames, tensor.Name)
|
||||||
|
}
|
||||||
|
if !slices.Equal(gotTensorNames, expectedTensorNames) {
|
||||||
|
t.Errorf("tensor names = %v, want %v", gotTensorNames, expectedTensorNames)
|
||||||
|
}
|
||||||
|
tokenTensor := f.TensorInfo("token_embd.weight")
|
||||||
|
if tokenTensor.Name != "token_embd.weight" {
|
||||||
|
t.Error("TensorInfo(\"token_embd.weight\") not found")
|
||||||
|
}
|
||||||
|
if len(tokenTensor.Shape) == 0 {
|
||||||
|
t.Error("TensorInfo(\"token_embd.weight\") has empty shape")
|
||||||
|
}
|
||||||
|
outputTensor := f.TensorInfo("output.weight")
|
||||||
|
if outputTensor.Name != "output.weight" {
|
||||||
|
t.Error("TensorInfo(\"output.weight\") not found")
|
||||||
|
}
|
||||||
|
if len(outputTensor.Shape) == 0 {
|
||||||
|
t.Error("TensorInfo(\"output.weight\") has empty shape")
|
||||||
|
}
|
||||||
|
var gotKeyCount int
|
||||||
|
for _, kv := range f.KeyValues() {
|
||||||
|
gotKeyCount++
|
||||||
|
if kv.Key == "" {
|
||||||
|
t.Error("found key value with empty key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gotKeyCount != 2 {
|
||||||
|
t.Errorf("iterated key count = %d, want %d", gotKeyCount, 2)
|
||||||
|
}
|
||||||
|
tensorInfo, reader, err := f.TensorReader("token_embd.weight")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TensorReader(\"token_embd.weight\") error: %v", err)
|
||||||
|
}
|
||||||
|
if tensorInfo.Name != "token_embd.weight" {
|
||||||
|
t.Errorf("TensorReader returned wrong tensor: %q", tensorInfo.Name)
|
||||||
|
}
|
||||||
|
if reader == nil {
|
||||||
|
t.Error("TensorReader returned nil reader")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRead(b *testing.B) {
|
||||||
|
// Create benchmark test file
|
||||||
|
tempDir := b.TempDir()
|
||||||
|
tempFile := filepath.Join(tempDir, "benchmark.gguf")
|
||||||
|
|
||||||
|
if err := createTestGGUFFile(tempFile, map[string]any{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"general.alignment": int64(32),
|
||||||
|
}, []testTensorInfo{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
{Name: "output.weight", Shape: []uint64{512, 1000}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get file info for reporting
|
||||||
|
info, err := os.Stat(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.Logf("Benchmark file size: %d bytes", info.Size())
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
f, err := Open(tempFile)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access some data to ensure it's actually being read
|
||||||
|
_ = f.KeyValue("general.architecture").String()
|
||||||
|
_ = f.KeyValue("general.alignment").Int()
|
||||||
|
_ = f.NumTensors()
|
||||||
|
_ = f.NumKeyValues()
|
||||||
|
|
||||||
|
// Iterate through some tensors
|
||||||
|
count := 0
|
||||||
|
for _, tensor := range f.TensorInfos() {
|
||||||
|
_ = tensor.Name
|
||||||
|
count++
|
||||||
|
if count >= 2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create test GGUF files
|
||||||
|
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Write GGUF magic
|
||||||
|
if _, err := file.Write([]byte("GGUF")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write version
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tensor count
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write metadata count
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write metadata
|
||||||
|
for key, value := range keyValues {
|
||||||
|
if err := writeKeyValue(file, key, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tensor info
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
if err := writeTensorInfo(file, tensor); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write some dummy tensor data
|
||||||
|
dummyData := make([]byte, 1024)
|
||||||
|
file.Write(dummyData)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testTensorInfo struct {
|
||||||
|
Name string
|
||||||
|
Shape []uint64
|
||||||
|
Type uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeKeyValue(file *os.File, key string, value any) error {
|
||||||
|
// Write key length and key
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(key)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write value based on type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := file.Write([]byte(v))
|
||||||
|
return err
|
||||||
|
case int64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case bool:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case float64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case []string:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, s := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(s)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case []int64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, i := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case []float64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, f := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported value type: %T", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
|
||||||
|
// Write tensor name
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(tensor.Name)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write dimensions
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, dim := range tensor.Shape {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write type
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write offset (dummy value)
|
||||||
|
return binary.Write(file, binary.LittleEndian, uint64(0))
|
||||||
|
}
|
||||||
102
fs/gguf/keyvalue.go
Normal file
102
fs/gguf/keyvalue.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyValue struct {
|
||||||
|
Key string
|
||||||
|
Value
|
||||||
|
}
|
||||||
|
|
||||||
|
type Value struct {
|
||||||
|
value any
|
||||||
|
}
|
||||||
|
|
||||||
|
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||||
|
vv := reflect.ValueOf(v.value)
|
||||||
|
if slices.Contains(kinds, vv.Kind()) {
|
||||||
|
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||||
|
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||||
|
ts = make([]T, vv.Len())
|
||||||
|
for i := range vv.Len() {
|
||||||
|
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||||
|
func (v Value) Int() int64 {
|
||||||
|
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||||
|
func (v Value) Ints() (i64s []int64) {
|
||||||
|
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||||
|
func (v Value) Uint() uint64 {
|
||||||
|
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||||
|
func (v Value) Uints() (u64s []uint64) {
|
||||||
|
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||||
|
func (v Value) Float() float64 {
|
||||||
|
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||||
|
func (v Value) Floats() (f64s []float64) {
|
||||||
|
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||||
|
func (v Value) Bool() bool {
|
||||||
|
return value[bool](v, reflect.Bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||||
|
func (v Value) Bools() (bools []bool) {
|
||||||
|
return values[bool](v, reflect.Bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||||
|
func (v Value) String() string {
|
||||||
|
return value[string](v, reflect.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||||
|
func (v Value) Strings() (strings []string) {
|
||||||
|
return values[string](v, reflect.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNil checks if the Value is nil. It returns true if the value is nil or if it is a nil pointer, interface, slice, map, channel, or function.
|
||||||
|
func (v Value) IsNil() bool {
|
||||||
|
if v.value == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for nil pointers, interfaces, slices, maps, channels, and functions
|
||||||
|
rv := reflect.ValueOf(v.value)
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
|
||||||
|
return rv.IsNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
208
fs/gguf/keyvalue_test.go
Normal file
208
fs/gguf/keyvalue_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||||
|
for key, value := range values {
|
||||||
|
if key == name {
|
||||||
|
matched = value
|
||||||
|
} else {
|
||||||
|
unmatched = append(unmatched, value...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValue(t *testing.T) {
|
||||||
|
values := map[string][]any{
|
||||||
|
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||||
|
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||||
|
"float64": {float32(42), float64(42)},
|
||||||
|
"string": {"42", "hello"},
|
||||||
|
"bool": {true, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("int64", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("int64", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if i64 := kv.Int(); i64 != 42 {
|
||||||
|
t.Errorf("expected 42, got %d", i64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if i64 := kv.Int(); i64 != 0 {
|
||||||
|
t.Errorf("expected 42, got %d", i64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uint64", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("uint64", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if u64 := kv.Uint(); u64 != 42 {
|
||||||
|
t.Errorf("expected 42, got %d", u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if u64 := kv.Uint(); u64 != 0 {
|
||||||
|
t.Errorf("expected 42, got %d", u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("float64", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("float64", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if f64 := kv.Float(); f64 != 42 {
|
||||||
|
t.Errorf("expected 42, got %f", f64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if f64 := kv.Float(); f64 != 0 {
|
||||||
|
t.Errorf("expected 42, got %f", f64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("string", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("string", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if s := kv.String(); s != v {
|
||||||
|
t.Errorf("expected 42, got %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if s := kv.String(); s != "" {
|
||||||
|
t.Errorf("expected 42, got %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bool", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("bool", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if b := kv.Bool(); b != v {
|
||||||
|
t.Errorf("expected true, got %v", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if b := kv.Bool(); b != false {
|
||||||
|
t.Errorf("expected false, got %v", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValues(t *testing.T) {
|
||||||
|
values := map[string][]any{
|
||||||
|
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||||
|
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||||
|
"float64s": {[]float32{42}, []float64{42}},
|
||||||
|
"strings": {[]string{"42"}, []string{"hello"}},
|
||||||
|
"bools": {[]bool{true}, []bool{false}},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("int64s", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("int64s", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||||
|
t.Errorf("diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if i64s := kv.Ints(); i64s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", i64s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uint64s", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("uint64s", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||||
|
t.Errorf("diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if u64s := kv.Uints(); u64s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", u64s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("float64s", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("float64s", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||||
|
t.Errorf("diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if f64s := kv.Floats(); f64s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", f64s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("strings", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("strings", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||||
|
t.Errorf("diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if s := kv.Strings(); s != nil {
|
||||||
|
t.Errorf("expected nil, got %v", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bools", func(t *testing.T) {
|
||||||
|
matched, unmatched := split("bools", values)
|
||||||
|
for _, v := range matched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||||
|
t.Errorf("diff: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unmatched {
|
||||||
|
kv := KeyValue{"key", Value{v}}
|
||||||
|
if b := kv.Bools(); b != nil {
|
||||||
|
t.Errorf("expected nil, got %v", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
88
fs/gguf/lazy.go
Normal file
88
fs/gguf/lazy.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"iter"
|
||||||
|
"log/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type lazy[T any] struct {
|
||||||
|
count uint64
|
||||||
|
next func() (T, bool)
|
||||||
|
stop func()
|
||||||
|
values []T
|
||||||
|
|
||||||
|
doneFunc func() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||||
|
it := lazy[T]{}
|
||||||
|
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
it.values = make([]T, 0)
|
||||||
|
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||||
|
for i := range it.count {
|
||||||
|
t, err := fn()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("error reading tensor", "index", i, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
it.values = append(it.values, t)
|
||||||
|
if !yield(t) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if it.doneFunc != nil {
|
||||||
|
it.doneFunc()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return &it, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||||
|
return func(yield func(T) bool) {
|
||||||
|
for _, v := range g.All() {
|
||||||
|
if !yield(v) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||||
|
return func(yield func(int, T) bool) {
|
||||||
|
for i := range int(g.count) {
|
||||||
|
if i < len(g.values) {
|
||||||
|
if !yield(i, g.values[i]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t, ok := g.next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(i, t) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *lazy[T]) rest() (collected bool) {
|
||||||
|
for {
|
||||||
|
_, ok := g.next()
|
||||||
|
collected = collected || ok
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return collected
|
||||||
|
}
|
||||||
34
fs/gguf/reader.go
Normal file
34
fs/gguf/reader.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type readSeeker struct {
|
||||||
|
rs io.ReadSeeker
|
||||||
|
br *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReadSeeker(rs io.ReadSeeker, size int) *readSeeker {
|
||||||
|
return &readSeeker{
|
||||||
|
rs: rs,
|
||||||
|
br: bufio.NewReaderSize(rs, size),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *readSeeker) Read(p []byte) (int, error) {
|
||||||
|
return b.br.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *readSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||||
|
if whence == io.SeekCurrent {
|
||||||
|
offset -= int64(b.br.Buffered())
|
||||||
|
}
|
||||||
|
n, err := b.rs.Seek(offset, whence)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
b.br.Reset(b.rs)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
284
fs/gguf/tensor.go
Normal file
284
fs/gguf/tensor.go
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
package gguf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TensorInfo struct {
|
||||||
|
Name string
|
||||||
|
Offset uint64
|
||||||
|
Shape []uint64
|
||||||
|
Type TensorType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorInfo) NumValues() int64 {
|
||||||
|
var numItems int64 = 1
|
||||||
|
for _, dim := range t.Shape {
|
||||||
|
numItems *= int64(dim)
|
||||||
|
}
|
||||||
|
return numItems
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumBytes returns the number of bytes in the tensor.
|
||||||
|
func (t TensorInfo) NumBytes() int64 {
|
||||||
|
return int64(float64(t.NumValues()) * t.Type.NumBytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorInfo) LogValue() slog.Value {
|
||||||
|
return slog.GroupValue(
|
||||||
|
slog.String("name", t.Name),
|
||||||
|
slog.Int64("offset", int64(t.Offset)),
|
||||||
|
slog.Any("shape", t.Shape),
|
||||||
|
slog.Int64("num_values", t.NumValues()),
|
||||||
|
slog.Int64("num_bytes", t.NumBytes()),
|
||||||
|
slog.Any("type", t.Type),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TensorType uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
TensorTypeF32 TensorType = iota
|
||||||
|
TensorTypeF16
|
||||||
|
TensorTypeQ4_0
|
||||||
|
TensorTypeQ4_1
|
||||||
|
|
||||||
|
// unexported // unused in gguf
|
||||||
|
tensorTypeQ4_2
|
||||||
|
tensorTypeQ4_3
|
||||||
|
|
||||||
|
TensorTypeQ5_0
|
||||||
|
TensorTypeQ5_1
|
||||||
|
TensorTypeQ8_0
|
||||||
|
TensorTypeQ8_1
|
||||||
|
TensorTypeQ2_K
|
||||||
|
TensorTypeQ3_K
|
||||||
|
TensorTypeQ4_K
|
||||||
|
TensorTypeQ5_K
|
||||||
|
TensorTypeQ6_K
|
||||||
|
TensorTypeQ8_K
|
||||||
|
|
||||||
|
// unexported // unquantizable by ollama
|
||||||
|
tensorTypeIQ2_XXS
|
||||||
|
tensorTypeIQ2_XS
|
||||||
|
tensorTypeIQ3_XXS
|
||||||
|
tensorTypeIQ1_S
|
||||||
|
tensorTypeIQ4_NL
|
||||||
|
tensorTypeIQ3_S
|
||||||
|
tensorTypeIQ2_S
|
||||||
|
tensorTypeIQ4_XS
|
||||||
|
|
||||||
|
TensorTypeI8
|
||||||
|
TensorTypeI16
|
||||||
|
TensorTypeI32
|
||||||
|
TensorTypeI64
|
||||||
|
TensorTypeF64
|
||||||
|
|
||||||
|
// unexported // unquantizable by ollama
|
||||||
|
tensorTypeIQ1_M
|
||||||
|
|
||||||
|
TensorTypeBF16
|
||||||
|
|
||||||
|
// unexported // unused in gguf
|
||||||
|
tensorTypeQ4_0_4_4
|
||||||
|
tensorTypeQ4_0_4_8
|
||||||
|
tensorTypeQ4_0_8_8
|
||||||
|
|
||||||
|
// unexported // unquantizable by ollama
|
||||||
|
tensorTypeTQ1_0
|
||||||
|
tensorTypeTQ2_0
|
||||||
|
|
||||||
|
// unexported // unused in gguf
|
||||||
|
tensorTypeIQ4_NL_4_4
|
||||||
|
tensorTypeIQ4_NL_4_8
|
||||||
|
tensorTypeIQ4_NL_8_8
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t TensorType) NumBytes() float64 {
|
||||||
|
return float64(t.typeSize()) / float64(t.blockSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorType) typeSize() int64 {
|
||||||
|
switch t {
|
||||||
|
case TensorTypeF32:
|
||||||
|
return 4
|
||||||
|
case TensorTypeF16:
|
||||||
|
return 2
|
||||||
|
case TensorTypeQ4_0:
|
||||||
|
return 2 + t.blockSize()/2
|
||||||
|
case TensorTypeQ4_1:
|
||||||
|
return 2 + 2 + t.blockSize()/2
|
||||||
|
case TensorTypeQ5_0:
|
||||||
|
return 2 + 4 + t.blockSize()/2
|
||||||
|
case TensorTypeQ5_1:
|
||||||
|
return 2 + 2 + 4 + t.blockSize()/2
|
||||||
|
case TensorTypeQ8_0:
|
||||||
|
return 2 + t.blockSize()
|
||||||
|
case TensorTypeQ8_1:
|
||||||
|
return 2 + 2 + t.blockSize()
|
||||||
|
case TensorTypeQ2_K:
|
||||||
|
return t.blockSize()/16 + t.blockSize()/4 + 2 + 2
|
||||||
|
case TensorTypeQ3_K:
|
||||||
|
return t.blockSize()/8 + t.blockSize()/4 + 12 + 2
|
||||||
|
case TensorTypeQ4_K:
|
||||||
|
return 2 + 2 + 12 + t.blockSize()/2
|
||||||
|
case TensorTypeQ5_K:
|
||||||
|
return 2 + 2 + 12 + t.blockSize()/8 + t.blockSize()/2
|
||||||
|
case TensorTypeQ6_K:
|
||||||
|
return t.blockSize()/2 + t.blockSize()/4 + t.blockSize()/16 + 2
|
||||||
|
case TensorTypeQ8_K:
|
||||||
|
return 4 + t.blockSize() + 2*t.blockSize()/16
|
||||||
|
case tensorTypeIQ2_XXS:
|
||||||
|
return 2 + 2*t.blockSize()/8
|
||||||
|
case tensorTypeIQ2_XS:
|
||||||
|
return 2 + 2*t.blockSize()/8 + t.blockSize()/32
|
||||||
|
case tensorTypeIQ3_XXS:
|
||||||
|
return 2 + t.blockSize()/4 + t.blockSize()/8
|
||||||
|
case tensorTypeIQ1_S:
|
||||||
|
return 2 + t.blockSize()/8 + t.blockSize()/16
|
||||||
|
case tensorTypeIQ4_NL:
|
||||||
|
return 2 + t.blockSize()/2
|
||||||
|
case tensorTypeIQ3_S:
|
||||||
|
return 2 + t.blockSize()/4 + t.blockSize()/8 + t.blockSize()/32 + 4
|
||||||
|
case tensorTypeIQ2_S:
|
||||||
|
return 2 + t.blockSize()/4 + t.blockSize()/16
|
||||||
|
case tensorTypeIQ4_XS:
|
||||||
|
return 2 + 2 + t.blockSize()/2 + t.blockSize()/64
|
||||||
|
case TensorTypeI8:
|
||||||
|
return 1
|
||||||
|
case TensorTypeI16:
|
||||||
|
return 2
|
||||||
|
case TensorTypeI32:
|
||||||
|
return 4
|
||||||
|
case TensorTypeI64:
|
||||||
|
return 8
|
||||||
|
case TensorTypeF64:
|
||||||
|
return 8
|
||||||
|
case tensorTypeIQ1_M:
|
||||||
|
return t.blockSize()/8 + t.blockSize()/16 + t.blockSize()/32
|
||||||
|
case TensorTypeBF16:
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorType) blockSize() int64 {
|
||||||
|
switch t {
|
||||||
|
case TensorTypeF32,
|
||||||
|
TensorTypeF16,
|
||||||
|
TensorTypeI8,
|
||||||
|
TensorTypeI16,
|
||||||
|
TensorTypeI32,
|
||||||
|
TensorTypeI64,
|
||||||
|
TensorTypeF64,
|
||||||
|
TensorTypeBF16:
|
||||||
|
return 1
|
||||||
|
case TensorTypeQ4_0,
|
||||||
|
TensorTypeQ4_1,
|
||||||
|
TensorTypeQ5_0,
|
||||||
|
TensorTypeQ5_1,
|
||||||
|
TensorTypeQ8_0,
|
||||||
|
TensorTypeQ8_1,
|
||||||
|
tensorTypeIQ4_NL:
|
||||||
|
return 32
|
||||||
|
default:
|
||||||
|
return 256
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorType) String() string {
|
||||||
|
switch t {
|
||||||
|
case TensorTypeF32:
|
||||||
|
return "f32"
|
||||||
|
case TensorTypeF16:
|
||||||
|
return "f16"
|
||||||
|
case TensorTypeQ4_0:
|
||||||
|
return "q4_0"
|
||||||
|
case TensorTypeQ4_1:
|
||||||
|
return "q4_1"
|
||||||
|
case tensorTypeQ4_2:
|
||||||
|
return "q4_2"
|
||||||
|
case tensorTypeQ4_3:
|
||||||
|
return "q4_3"
|
||||||
|
case TensorTypeQ5_0:
|
||||||
|
return "q5_0"
|
||||||
|
case TensorTypeQ5_1:
|
||||||
|
return "q5_1"
|
||||||
|
case TensorTypeQ8_0:
|
||||||
|
return "q8_0"
|
||||||
|
case TensorTypeQ8_1:
|
||||||
|
return "q8_1"
|
||||||
|
case TensorTypeQ2_K:
|
||||||
|
return "q2_k"
|
||||||
|
case TensorTypeQ3_K:
|
||||||
|
return "q3_k"
|
||||||
|
case TensorTypeQ4_K:
|
||||||
|
return "q4_k"
|
||||||
|
case TensorTypeQ5_K:
|
||||||
|
return "q5_k"
|
||||||
|
case TensorTypeQ6_K:
|
||||||
|
return "q6_k"
|
||||||
|
case TensorTypeQ8_K:
|
||||||
|
return "q8_k"
|
||||||
|
case tensorTypeIQ2_XXS:
|
||||||
|
return "iq2_xxs"
|
||||||
|
case tensorTypeIQ2_XS:
|
||||||
|
return "iq2_xs"
|
||||||
|
case tensorTypeIQ3_XXS:
|
||||||
|
return "iq3_xxs"
|
||||||
|
case tensorTypeIQ1_S:
|
||||||
|
return "iq1_s"
|
||||||
|
case tensorTypeIQ4_NL:
|
||||||
|
return "iq4_nl"
|
||||||
|
case tensorTypeIQ3_S:
|
||||||
|
return "iq3_s"
|
||||||
|
case tensorTypeIQ2_S:
|
||||||
|
return "iq2_s"
|
||||||
|
case tensorTypeIQ4_XS:
|
||||||
|
return "iq4_xs"
|
||||||
|
case TensorTypeI8:
|
||||||
|
return "i8"
|
||||||
|
case TensorTypeI16:
|
||||||
|
return "i16"
|
||||||
|
case TensorTypeI32:
|
||||||
|
return "i32"
|
||||||
|
case TensorTypeI64:
|
||||||
|
return "i64"
|
||||||
|
case TensorTypeF64:
|
||||||
|
return "f64"
|
||||||
|
case tensorTypeIQ1_M:
|
||||||
|
return "iq1_m"
|
||||||
|
case TensorTypeBF16:
|
||||||
|
return "bf16"
|
||||||
|
case tensorTypeQ4_0_4_4:
|
||||||
|
return "q4_0_4_4"
|
||||||
|
case tensorTypeQ4_0_4_8:
|
||||||
|
return "q4_0_4_8"
|
||||||
|
case tensorTypeQ4_0_8_8:
|
||||||
|
return "q4_0_8_8"
|
||||||
|
case tensorTypeTQ1_0:
|
||||||
|
return "tq1_0"
|
||||||
|
case tensorTypeTQ2_0:
|
||||||
|
return "tq2_0"
|
||||||
|
case tensorTypeIQ4_NL_4_4:
|
||||||
|
return "iq4_nl_4_4"
|
||||||
|
case tensorTypeIQ4_NL_4_8:
|
||||||
|
return "iq4_nl_4_8"
|
||||||
|
case tensorTypeIQ4_NL_8_8:
|
||||||
|
return "iq4_nl_8_8"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TensorType) LogValue() slog.Value {
|
||||||
|
return slog.GroupValue(
|
||||||
|
slog.Uint64("value", uint64(t)),
|
||||||
|
slog.String("name", strings.ToUpper(t.String())),
|
||||||
|
slog.Int64("size", t.typeSize()),
|
||||||
|
slog.Int64("block_size", t.blockSize()),
|
||||||
|
slog.Float64("num_bytes", t.NumBytes()),
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/gguf"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/thinking"
|
"github.com/ollama/ollama/thinking"
|
||||||
@@ -73,23 +73,21 @@ func (m *Model) Capabilities() []model.Capability {
|
|||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|
||||||
// Check for completion capability
|
// Check for completion capability
|
||||||
r, err := os.Open(m.ModelPath)
|
f, err := gguf.Open(m.ModelPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
defer r.Close()
|
defer f.Close()
|
||||||
|
|
||||||
f, err := ggml.Decode(r, 1024)
|
embedding := f.KeyValue("pooling_type")
|
||||||
if err == nil {
|
if !embedding.Value.IsNil() {
|
||||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
|
||||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||||
} else {
|
} else {
|
||||||
|
// If no embedding is specified, we assume the model supports completion
|
||||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||||
}
|
}
|
||||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
vision := f.KeyValue("vision.block_count")
|
||||||
|
if !vision.Value.IsNil() {
|
||||||
capabilities = append(capabilities, model.CapabilityVision)
|
capabilities = append(capabilities, model.CapabilityVision)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
slog.Error("couldn't decode ggml", "error", err)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
slog.Error("couldn't open model file", "error", err)
|
slog.Error("couldn't open model file", "error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -13,81 +12,200 @@ import (
|
|||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Constants for GGUF magic bytes and version
|
// GGUF type constants (matching gguf package)
|
||||||
var (
|
const (
|
||||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
typeUint8 = uint32(0)
|
||||||
ggufVer = uint32(3) // Version 3
|
typeInt8 = uint32(1)
|
||||||
|
typeUint16 = uint32(2)
|
||||||
|
typeInt16 = uint32(3)
|
||||||
|
typeUint32 = uint32(4)
|
||||||
|
typeInt32 = uint32(5)
|
||||||
|
typeFloat32 = uint32(6)
|
||||||
|
typeBool = uint32(7)
|
||||||
|
typeString = uint32(8)
|
||||||
|
typeArray = uint32(9)
|
||||||
|
typeUint64 = uint32(10)
|
||||||
|
typeInt64 = uint32(11)
|
||||||
|
typeFloat64 = uint32(12)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helper function to create mock GGUF data
|
type testTensorInfo struct {
|
||||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
Name string
|
||||||
var buf bytes.Buffer
|
Shape []uint64
|
||||||
|
Type uint32
|
||||||
// Write GGUF header
|
|
||||||
buf.Write(ggufMagic)
|
|
||||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
|
||||||
|
|
||||||
// Write tensor count (0 for our test)
|
|
||||||
var numTensors uint64 = 0
|
|
||||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
|
||||||
|
|
||||||
// Calculate number of metadata entries
|
|
||||||
numMetaEntries := uint64(1) // architecture entry
|
|
||||||
if vision {
|
|
||||||
numMetaEntries++
|
|
||||||
}
|
|
||||||
// Add embedding entry if architecture is "bert"
|
|
||||||
if architecture == "bert" {
|
|
||||||
numMetaEntries++
|
|
||||||
}
|
|
||||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
|
||||||
|
|
||||||
// Write architecture metadata
|
|
||||||
archKey := "general.architecture"
|
|
||||||
keyLen := uint64(len(archKey))
|
|
||||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
|
||||||
buf.WriteString(archKey)
|
|
||||||
|
|
||||||
// String type (8)
|
|
||||||
var strType uint32 = 8
|
|
||||||
binary.Write(&buf, binary.LittleEndian, strType)
|
|
||||||
|
|
||||||
// String length
|
|
||||||
strLen := uint64(len(architecture))
|
|
||||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
|
||||||
buf.WriteString(architecture)
|
|
||||||
|
|
||||||
if vision {
|
|
||||||
visionKey := architecture + ".vision.block_count"
|
|
||||||
keyLen = uint64(len(visionKey))
|
|
||||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
|
||||||
buf.WriteString(visionKey)
|
|
||||||
|
|
||||||
// uint32 type (4)
|
|
||||||
var uint32Type uint32 = 4
|
|
||||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
|
||||||
|
|
||||||
// uint32 value (1)
|
|
||||||
var countVal uint32 = 1
|
|
||||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
|
||||||
}
|
|
||||||
// Write embedding metadata if architecture is "bert"
|
|
||||||
if architecture == "bert" {
|
|
||||||
poolKey := architecture + ".pooling_type"
|
|
||||||
keyLen = uint64(len(poolKey))
|
|
||||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
|
||||||
buf.WriteString(poolKey)
|
|
||||||
|
|
||||||
// uint32 type (4)
|
|
||||||
var uint32Type uint32 = 4
|
|
||||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
|
||||||
|
|
||||||
// uint32 value (1)
|
|
||||||
var poolingVal uint32 = 1
|
|
||||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.Bytes()
|
// Helper function to create test GGUF files (matching gguf package approach)
|
||||||
|
func createTestGGUFFile(path string, keyValues map[string]any, tensors []testTensorInfo) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Write GGUF magic
|
||||||
|
if _, err := file.Write([]byte("GGUF")); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write version
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(3)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tensor count
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensors))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write metadata count
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(keyValues))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write metadata
|
||||||
|
for key, value := range keyValues {
|
||||||
|
if err := writeKeyValue(file, key, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tensor info
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
if err := writeTensorInfo(file, tensor); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write some dummy tensor data
|
||||||
|
dummyData := make([]byte, 1024)
|
||||||
|
file.Write(dummyData)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeKeyValue(file *os.File, key string, value any) error {
|
||||||
|
// Write key length and key
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(key))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(key)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write value based on type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(typeString)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := file.Write([]byte(v))
|
||||||
|
return err
|
||||||
|
case int64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case uint32:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeUint32); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case bool:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeBool); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case float64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(typeFloat64)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return binary.Write(file, binary.LittleEndian, v)
|
||||||
|
case []string:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(typeArray)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeString); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, s := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(s))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(s)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case []int64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(typeArray)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeInt64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, i := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, i); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case []float64:
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeArray); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, typeFloat64); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, f := range v {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, f); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported value type: %T", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTensorInfo(file *os.File, tensor testTensorInfo) error {
|
||||||
|
// Write tensor name
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint64(len(tensor.Name))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := file.Write([]byte(tensor.Name)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write dimensions
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, uint32(len(tensor.Shape))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, dim := range tensor.Shape {
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, dim); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write type
|
||||||
|
if err := binary.Write(file, binary.LittleEndian, tensor.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write offset (dummy value)
|
||||||
|
return binary.Write(file, binary.LittleEndian, uint64(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelCapabilities(t *testing.T) {
|
func TestModelCapabilities(t *testing.T) {
|
||||||
@@ -101,13 +219,38 @@ func TestModelCapabilities(t *testing.T) {
|
|||||||
// Create a simple model file for tests that don't depend on GGUF content
|
// Create a simple model file for tests that don't depend on GGUF content
|
||||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||||
|
|
||||||
if err := errors.Join(
|
// Create completion model (llama architecture without vision)
|
||||||
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
|
if err := createTestGGUFFile(completionModelPath, map[string]any{
|
||||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
"general.architecture": "llama",
|
||||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
}, []testTensorInfo{
|
||||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
); err != nil {
|
}); err != nil {
|
||||||
t.Fatalf("Failed to create model files: %v", err)
|
t.Fatalf("Failed to create completion model file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create vision model (llama architecture with vision block count)
|
||||||
|
if err := createTestGGUFFile(visionModelPath, map[string]any{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.vision.block_count": uint32(1),
|
||||||
|
}, []testTensorInfo{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Failed to create vision model file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create embedding model (bert architecture with pooling type)
|
||||||
|
if err := createTestGGUFFile(embeddingModelPath, map[string]any{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(1),
|
||||||
|
}, []testTensorInfo{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create simple model file for tests that don't depend on GGUF content
|
||||||
|
if err := os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to create simple model file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||||
@@ -231,12 +374,29 @@ func TestModelCheckCapabilities(t *testing.T) {
|
|||||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||||
|
|
||||||
if err := errors.Join(
|
// Create vision model (llama architecture with vision block count)
|
||||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
if err := createTestGGUFFile(visionModelPath, map[string]any{
|
||||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
"general.architecture": "llama",
|
||||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
"llama.vision.block_count": uint32(1),
|
||||||
); err != nil {
|
}, []testTensorInfo{
|
||||||
t.Fatalf("Failed to create model files: %v", err)
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Failed to create vision model file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create embedding model (bert architecture with pooling type)
|
||||||
|
if err := createTestGGUFFile(embeddingModelPath, map[string]any{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(1),
|
||||||
|
}, []testTensorInfo{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1000, 512}, Type: 1}, // F16
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create simple model file for tests that don't depend on GGUF content
|
||||||
|
if err := os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to create simple model file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||||
|
|||||||
Reference in New Issue
Block a user