Compare commits

..

6 Commits

Author SHA1 Message Date
Jeffrey Morgan
ed567ef43b Revert "ggml: Export GPU UUIDs" (#11115)
This reverts commit aaa7818000.
2025-06-18 05:45:00 -07:00
Jeffrey Morgan
a6e64fbdf2 Revert "feat: incremental gguf parser (#10822)" (#11114)
This reverts commit 6b04cad7e8.
2025-06-18 05:42:44 -07:00
曹家巧
60cfa2a203 cache: fix comment function name in cache.go (#11110) 2025-06-18 05:21:45 -07:00
Jeffrey Morgan
55bbf3b4a1 tools: return empty arguments object instead of null (#11113) 2025-06-18 05:20:43 -07:00
Jeffrey Morgan
6bda1d2479 tools: fix parsing tool calls without any parameters (#11101)
Fixes issue where tool calls that don't expect any parameters were
not being parsed. This also fixes two additional issues: one where
2+ tool calls would not be correctly parsed, and cases where tool calls
with invalid parameters would still get parsed
2025-06-17 10:51:43 -07:00
Jeffrey Morgan
9e125d884c model: treat 'user defined' tokens as special tokens (#11077) 2025-06-16 16:03:16 -07:00
27 changed files with 642 additions and 1873 deletions

View File

@@ -1,347 +0,0 @@
package gguf
import (
"bytes"
"cmp"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"slices"
"strings"
)
const (
typeUint8 uint32 = iota
typeInt8
typeUint16
typeInt16
typeUint32
typeInt32
typeFloat32
typeBool
typeString
typeArray
typeUint64
typeInt64
typeFloat64
)
var ErrUnsupported = errors.New("unsupported")
type File struct {
Magic [4]byte
Version uint32
keyValues *lazy[KeyValue]
tensors *lazy[TensorInfo]
offset int64
file *os.File
reader *bufferedReader
bts []byte
}
func Open(path string) (f *File, err error) {
f = &File{bts: make([]byte, 4096)}
f.file, err = os.Open(path)
if err != nil {
return nil, err
}
f.reader = newBufferedReader(f.file, 32<<10)
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
return nil, err
}
if bytes.Equal(f.Magic[:], []byte("gguf")) {
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
}
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
return nil, err
}
if f.Version != 3 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}
f.tensors, err = newLazy(f, f.readTensor)
if err != nil {
return nil, err
}
f.tensors.successFunc = func() error {
offset := f.reader.offset
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
f.offset = offset + (alignment-offset%alignment)%alignment
return nil
}
f.keyValues, err = newLazy(f, f.readKeyValue)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) readTensor() (TensorInfo, error) {
name, err := readString(f)
if err != nil {
return TensorInfo{}, err
}
dims, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
shape := make([]uint64, dims)
for i := range dims {
shape[i], err = read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
}
type_, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
offset, err := read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
return TensorInfo{
Name: name,
Offset: offset,
Shape: shape,
Type: TensorType(type_),
}, nil
}
func (f *File) readKeyValue() (KeyValue, error) {
key, err := readString(f)
if err != nil {
return KeyValue{}, err
}
t, err := read[uint32](f)
if err != nil {
return KeyValue{}, err
}
value, err := func() (any, error) {
switch t {
case typeUint8:
return read[uint8](f)
case typeInt8:
return read[int8](f)
case typeUint16:
return read[uint16](f)
case typeInt16:
return read[int16](f)
case typeUint32:
return read[uint32](f)
case typeInt32:
return read[int32](f)
case typeUint64:
return read[uint64](f)
case typeInt64:
return read[int64](f)
case typeFloat32:
return read[float32](f)
case typeFloat64:
return read[float64](f)
case typeBool:
return read[bool](f)
case typeString:
return readString(f)
case typeArray:
return readArray(f)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}()
if err != nil {
return KeyValue{}, err
}
return KeyValue{
Key: key,
Value: Value{value},
}, nil
}
func read[T any](f *File) (t T, err error) {
err = binary.Read(f.reader, binary.LittleEndian, &t)
return t, err
}
func readString(f *File) (string, error) {
n, err := read[uint64](f)
if err != nil {
return "", err
}
if int(n) > len(f.bts) {
f.bts = make([]byte, n)
}
bts := f.bts[:n]
if _, err := io.ReadFull(f.reader, bts); err != nil {
return "", err
}
defer clear(bts)
return string(bts), nil
}
func readArray(f *File) (any, error) {
t, err := read[uint32](f)
if err != nil {
return nil, err
}
n, err := read[uint64](f)
if err != nil {
return nil, err
}
switch t {
case typeUint8:
return readArrayData[uint8](f, n)
case typeInt8:
return readArrayData[int8](f, n)
case typeUint16:
return readArrayData[uint16](f, n)
case typeInt16:
return readArrayData[int16](f, n)
case typeUint32:
return readArrayData[uint32](f, n)
case typeInt32:
return readArrayData[int32](f, n)
case typeUint64:
return readArrayData[uint64](f, n)
case typeInt64:
return readArrayData[int64](f, n)
case typeFloat32:
return readArrayData[float32](f, n)
case typeFloat64:
return readArrayData[float64](f, n)
case typeBool:
return readArrayData[bool](f, n)
case typeString:
return readArrayString(f, n)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
s = make([]T, n)
for i := range n {
e, err := read[T](f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func readArrayString(f *File, n uint64) (s []string, err error) {
s = make([]string, n)
for i := range n {
e, err := readString(f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func (f *File) Close() error {
f.keyValues.stop()
f.tensors.stop()
return f.file.Close()
}
func (f *File) KeyValue(key string) KeyValue {
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
key = f.KeyValue("general.architecture").String() + "." + key
}
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
return kv.Key == key
}); index >= 0 {
return f.keyValues.values[index]
}
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
if keyValue.Key == key {
return keyValue
}
}
return KeyValue{}
}
func (f *File) NumKeyValues() int {
return int(f.keyValues.count)
}
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
return f.keyValues.All()
}
func (f *File) TensorInfo(name string) TensorInfo {
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
return t.Name == name
}); index >= 0 {
return f.tensors.values[index]
}
// fast-forward through key values if we haven't already
_ = f.keyValues.rest()
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
if tensor.Name == name {
return tensor
}
}
return TensorInfo{}
}
func (f *File) NumTensors() int {
return int(f.tensors.count)
}
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
// fast forward through key values if we haven't already
f.keyValues.rest()
return f.tensors.All()
}
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
t := f.TensorInfo(name)
if t.NumBytes() == 0 {
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
}
// fast forward through tensor info if we haven't already
_ = f.tensors.rest()
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}

View File

@@ -1,249 +0,0 @@
package gguf_test
import (
"bytes"
"os"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs/gguf"
)
func createBinFile(tb testing.TB) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(8),
"llama.embedding_length": uint32(3),
"llama.attention.head_count": uint32(2),
"llama.attention.head_count_kv": uint32(2),
"llama.attention.key_length": uint32(3),
"llama.rope.dimension_count": uint32(4),
"llama.rope.freq_base": float32(10000.0),
"llama.rope.freq_scale": float32(1.0),
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
"tokenizer.ggml.eos_token_id": uint32(0),
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
"tokenizer.ggml.tokens": []string{"hello", "world"},
"tokenizer.ggml.scores": []float32{0, 1},
}
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{2, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
},
{
Name: "output.weight",
Kind: 0,
Shape: []uint64{3, 2},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
},
}
for i := range 8 {
tensors = append(tensors, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
})
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestRead(t *testing.T) {
f, err := gguf.Open(createBinFile(t))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if got := f.KeyValue("does.not.exist").Valid(); got {
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
}
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if got.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
}
if got := f.KeyValue("block_count").Uint(); got != 8 {
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
}
var kvs []string
for _, kv := range f.KeyValues() {
if !kv.Valid() {
t.Error("found invalid key-value pair:", kv)
}
kvs = append(kvs, kv.Key)
}
if len(kvs) != f.NumKeyValues() {
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
}
if diff := cmp.Diff(kvs, []string{
"general.architecture",
"llama.block_count",
"llama.embedding_length",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.key_length",
"llama.rope.dimension_count",
"llama.rope.freq_base",
"llama.rope.freq_scale",
"llama.attention.layer_norm_rms_epsilon",
"tokenizer.ggml.eos_token_id",
"tokenizer.ggml.eos_token_ids",
"tokenizer.ggml.tokens",
"tokenizer.ggml.scores",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
}
var tis []string
for _, ti := range f.TensorInfos() {
if !ti.Valid() {
t.Error("found invalid tensor info:", ti)
}
tis = append(tis, ti.Name)
}
if len(tis) != f.NumTensors() {
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
}
if diff := cmp.Diff(tis, []string{
"token_embd.weight",
"output.weight",
"blk.0.attn_q.weight",
"blk.0.attn_k.weight",
"blk.0.attn_v.weight",
"blk.0.attn_output.weight",
"blk.1.attn_q.weight",
"blk.1.attn_k.weight",
"blk.1.attn_v.weight",
"blk.1.attn_output.weight",
"blk.2.attn_q.weight",
"blk.2.attn_k.weight",
"blk.2.attn_v.weight",
"blk.2.attn_output.weight",
"blk.3.attn_q.weight",
"blk.3.attn_k.weight",
"blk.3.attn_v.weight",
"blk.3.attn_output.weight",
"blk.4.attn_q.weight",
"blk.4.attn_k.weight",
"blk.4.attn_v.weight",
"blk.4.attn_output.weight",
"blk.5.attn_q.weight",
"blk.5.attn_k.weight",
"blk.5.attn_v.weight",
"blk.5.attn_output.weight",
"blk.6.attn_q.weight",
"blk.6.attn_k.weight",
"blk.6.attn_v.weight",
"blk.6.attn_output.weight",
"blk.7.attn_q.weight",
"blk.7.attn_k.weight",
"blk.7.attn_v.weight",
"blk.7.attn_output.weight",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
}
ti, r, err := f.TensorReader("output.weight")
if err != nil {
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
}
if ti.Name != "output.weight" {
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if ti.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
}
var b bytes.Buffer
if _, err := b.ReadFrom(r); err != nil {
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
}
if b.Len() != int(ti.NumBytes()) {
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
}
}
func BenchmarkRead(b *testing.B) {
b.ReportAllocs()
p := createBinFile(b)
for b.Loop() {
f, err := gguf.Open(p)
if err != nil {
b.Fatal(err)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
b.Errorf("got = %q, want %q", got, "llama")
}
// Iterate through some tensors
for range f.TensorInfos() {
}
f.Close()
}
}

View File

@@ -1,90 +0,0 @@
package gguf
import (
"reflect"
"slices"
)
type KeyValue struct {
Key string
Value
}
func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil
}
type Value struct {
value any
}
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
vv := reflect.ValueOf(v.value)
if slices.Contains(kinds, vv.Kind()) {
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
}
return
}
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
switch vv := reflect.ValueOf(v.value); vv.Kind() {
case reflect.Slice:
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
ts = make([]T, vv.Len())
for i := range vv.Len() {
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
}
}
}
return
}
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
return value[float64](v, reflect.Float32, reflect.Float64)
}
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
return values[float64](v, reflect.Float32, reflect.Float64)
}
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
return value[bool](v, reflect.Bool)
}
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
return values[bool](v, reflect.Bool)
}
// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
return value[string](v, reflect.String)
}
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
return values[string](v, reflect.String)
}

View File

@@ -1,208 +0,0 @@
package gguf
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
for key, value := range values {
if key == name {
matched = value
} else {
unmatched = append(unmatched, value...)
}
}
return
}
func TestValue(t *testing.T) {
values := map[string][]any{
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
"float64": {float32(42), float64(42)},
"string": {"42", "hello"},
"bool": {true, false},
}
t.Run("int64", func(t *testing.T) {
matched, unmatched := split("int64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 42 {
t.Errorf("expected 42, got %d", i64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 0 {
t.Errorf("expected 42, got %d", i64)
}
}
})
t.Run("uint64", func(t *testing.T) {
matched, unmatched := split("uint64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 42 {
t.Errorf("expected 42, got %d", u64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 0 {
t.Errorf("expected 42, got %d", u64)
}
}
})
t.Run("float64", func(t *testing.T) {
matched, unmatched := split("float64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 42 {
t.Errorf("expected 42, got %f", f64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 0 {
t.Errorf("expected 42, got %f", f64)
}
}
})
t.Run("string", func(t *testing.T) {
matched, unmatched := split("string", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != v {
t.Errorf("expected 42, got %s", s)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != "" {
t.Errorf("expected 42, got %s", s)
}
}
})
t.Run("bool", func(t *testing.T) {
matched, unmatched := split("bool", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != v {
t.Errorf("expected true, got %v", b)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != false {
t.Errorf("expected false, got %v", b)
}
}
})
}
func TestValues(t *testing.T) {
values := map[string][]any{
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
"float64s": {[]float32{42}, []float64{42}},
"strings": {[]string{"42"}, []string{"hello"}},
"bools": {[]bool{true}, []bool{false}},
}
t.Run("int64s", func(t *testing.T) {
matched, unmatched := split("int64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64s := kv.Ints(); i64s != nil {
t.Errorf("expected nil, got %v", i64s)
}
}
})
t.Run("uint64s", func(t *testing.T) {
matched, unmatched := split("uint64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64s := kv.Uints(); u64s != nil {
t.Errorf("expected nil, got %v", u64s)
}
}
})
t.Run("float64s", func(t *testing.T) {
matched, unmatched := split("float64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64s := kv.Floats(); f64s != nil {
t.Errorf("expected nil, got %v", f64s)
}
}
})
t.Run("strings", func(t *testing.T) {
matched, unmatched := split("strings", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.Strings(); s != nil {
t.Errorf("expected nil, got %v", s)
}
}
})
t.Run("bools", func(t *testing.T) {
matched, unmatched := split("bools", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bools(); b != nil {
t.Errorf("expected nil, got %v", b)
}
}
})
}

View File

@@ -1,89 +0,0 @@
package gguf
import (
"encoding/binary"
"iter"
"log/slog"
)
type lazy[T any] struct {
count uint64
next func() (T, bool)
stop func()
values []T
// successFunc is called when all values have been successfully read.
successFunc func() error
}
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
it := lazy[T]{}
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
return nil, err
}
it.values = make([]T, 0)
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
for i := range it.count {
t, err := fn()
if err != nil {
slog.Error("error reading tensor", "index", i, "error", err)
return
}
it.values = append(it.values, t)
if !yield(t) {
break
}
}
if it.successFunc != nil {
it.successFunc()
}
})
return &it, nil
}
func (g *lazy[T]) Values() iter.Seq[T] {
return func(yield func(T) bool) {
for _, v := range g.All() {
if !yield(v) {
break
}
}
}
}
func (g *lazy[T]) All() iter.Seq2[int, T] {
return func(yield func(int, T) bool) {
for i := range int(g.count) {
if i < len(g.values) {
if !yield(i, g.values[i]) {
break
}
} else {
t, ok := g.next()
if !ok {
break
}
if !yield(i, t) {
break
}
}
}
}
}
func (g *lazy[T]) rest() (collected bool) {
for {
_, ok := g.next()
collected = collected || ok
if !ok {
break
}
}
return collected
}

View File

@@ -1,23 +0,0 @@
package gguf
import (
"bufio"
"io"
)
type bufferedReader struct {
offset int64
*bufio.Reader
}
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
return &bufferedReader{
Reader: bufio.NewReaderSize(rs, size),
}
}
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
n, err = rs.Reader.Read(p)
rs.offset += int64(n)
return n, err
}

View File

@@ -1,288 +0,0 @@
package gguf
import (
"log/slog"
"strings"
)
type TensorInfo struct {
Name string
Offset uint64
Shape []uint64
Type TensorType
}
func (ti TensorInfo) Valid() bool {
return ti.Name != "" && ti.NumBytes() > 0
}
func (ti TensorInfo) NumValues() int64 {
var numItems int64 = 1
for _, dim := range ti.Shape {
numItems *= int64(dim)
}
return numItems
}
// NumBytes returns the number of bytes in the tensor.
func (ti TensorInfo) NumBytes() int64 {
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
}
func (ti TensorInfo) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", ti.Name),
slog.Int64("offset", int64(ti.Offset)),
slog.Any("shape", ti.Shape),
slog.Int64("num_values", ti.NumValues()),
slog.Int64("num_bytes", ti.NumBytes()),
slog.Any("type", ti.Type),
)
}
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
// unexported // unused in gguf
tensorTypeQ4_2
tensorTypeQ4_3
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
// unexported // unquantizable by ollama
tensorTypeIQ2_XXS
tensorTypeIQ2_XS
tensorTypeIQ3_XXS
tensorTypeIQ1_S
tensorTypeIQ4_NL
tensorTypeIQ3_S
tensorTypeIQ2_S
tensorTypeIQ4_XS
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
// unexported // unquantizable by ollama
tensorTypeIQ1_M
TensorTypeBF16
// unexported // unused in gguf
tensorTypeQ4_0_4_4
tensorTypeQ4_0_4_8
tensorTypeQ4_0_8_8
// unexported // unquantizable by ollama
tensorTypeTQ1_0
tensorTypeTQ2_0
// unexported // unused in gguf
tensorTypeIQ4_NL_4_4
tensorTypeIQ4_NL_4_8
tensorTypeIQ4_NL_8_8
)
func (tt TensorType) NumBytes() float64 {
return float64(tt.typeSize()) / float64(tt.blockSize())
}
func (tt TensorType) typeSize() int64 {
switch tt {
case TensorTypeF32:
return 4
case TensorTypeF16:
return 2
case TensorTypeQ4_0:
return 2 + tt.blockSize()/2
case TensorTypeQ4_1:
return 2 + 2 + tt.blockSize()/2
case TensorTypeQ5_0:
return 2 + 4 + tt.blockSize()/2
case TensorTypeQ5_1:
return 2 + 2 + 4 + tt.blockSize()/2
case TensorTypeQ8_0:
return 2 + tt.blockSize()
case TensorTypeQ8_1:
return 2 + 2 + tt.blockSize()
case TensorTypeQ2_K:
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
case TensorTypeQ3_K:
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
case TensorTypeQ4_K:
return 2 + 2 + 12 + tt.blockSize()/2
case TensorTypeQ5_K:
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
case TensorTypeQ6_K:
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
case TensorTypeQ8_K:
return 4 + tt.blockSize() + 2*tt.blockSize()/16
case tensorTypeIQ2_XXS:
return 2 + 2*tt.blockSize()/8
case tensorTypeIQ2_XS:
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
case tensorTypeIQ3_XXS:
return 2 + tt.blockSize()/4 + tt.blockSize()/8
case tensorTypeIQ1_S:
return 2 + tt.blockSize()/8 + tt.blockSize()/16
case tensorTypeIQ4_NL:
return 2 + tt.blockSize()/2
case tensorTypeIQ3_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
case tensorTypeIQ2_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/16
case tensorTypeIQ4_XS:
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
case TensorTypeI8:
return 1
case TensorTypeI16:
return 2
case TensorTypeI32:
return 4
case TensorTypeI64:
return 8
case TensorTypeF64:
return 8
case tensorTypeIQ1_M:
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (tt TensorType) blockSize() int64 {
switch tt {
case TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL:
return 32
default:
return 256
}
}
func (tt TensorType) String() string {
switch tt {
case TensorTypeF32:
return "f32"
case TensorTypeF16:
return "f16"
case TensorTypeQ4_0:
return "q4_0"
case TensorTypeQ4_1:
return "q4_1"
case tensorTypeQ4_2:
return "q4_2"
case tensorTypeQ4_3:
return "q4_3"
case TensorTypeQ5_0:
return "q5_0"
case TensorTypeQ5_1:
return "q5_1"
case TensorTypeQ8_0:
return "q8_0"
case TensorTypeQ8_1:
return "q8_1"
case TensorTypeQ2_K:
return "q2_k"
case TensorTypeQ3_K:
return "q3_k"
case TensorTypeQ4_K:
return "q4_k"
case TensorTypeQ5_K:
return "q5_k"
case TensorTypeQ6_K:
return "q6_k"
case TensorTypeQ8_K:
return "q8_k"
case tensorTypeIQ2_XXS:
return "iq2_xxs"
case tensorTypeIQ2_XS:
return "iq2_xs"
case tensorTypeIQ3_XXS:
return "iq3_xxs"
case tensorTypeIQ1_S:
return "iq1_s"
case tensorTypeIQ4_NL:
return "iq4_nl"
case tensorTypeIQ3_S:
return "iq3_s"
case tensorTypeIQ2_S:
return "iq2_s"
case tensorTypeIQ4_XS:
return "iq4_xs"
case TensorTypeI8:
return "i8"
case TensorTypeI16:
return "i16"
case TensorTypeI32:
return "i32"
case TensorTypeI64:
return "i64"
case TensorTypeF64:
return "f64"
case tensorTypeIQ1_M:
return "iq1_m"
case TensorTypeBF16:
return "bf16"
case tensorTypeQ4_0_4_4:
return "q4_0_4_4"
case tensorTypeQ4_0_4_8:
return "q4_0_4_8"
case tensorTypeQ4_0_8_8:
return "q4_0_8_8"
case tensorTypeTQ1_0:
return "tq1_0"
case tensorTypeTQ2_0:
return "tq2_0"
case tensorTypeIQ4_NL_4_4:
return "iq4_nl_4_4"
case tensorTypeIQ4_NL_4_8:
return "iq4_nl_4_8"
case tensorTypeIQ4_NL_8_8:
return "iq4_nl_8_8"
default:
return "unknown"
}
}
func (tt TensorType) LogValue() slog.Value {
return slog.GroupValue(
slog.Uint64("value", uint64(tt)),
slog.String("name", strings.ToUpper(tt.String())),
slog.Int64("size", tt.typeSize()),
slog.Int64("block_size", tt.blockSize()),
slog.Float64("num_bytes", tt.NumBytes()),
)
}

2
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.7.0
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=

View File

@@ -1,102 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Thu, 24 Apr 2025 14:48:51 -0700
Subject: [PATCH] ggml: Export GPU UUIDs
This enables matching up devices and information reported by the backend
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal.m | 1 +
3 files changed, 35 insertions(+)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 74e46716..a880df33 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -152,6 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
+ const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cb0d8528..4c829153 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
+ std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2896,6 +2897,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}
+static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->uuid.c_str();
+}
+
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
@@ -2910,6 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
+ props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
+ #if !defined(GGML_USE_HIP)
+ char uuid[64];
+ snprintf(uuid, sizeof(uuid),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ (unsigned char)prop.uuid.bytes[0],
+ (unsigned char)prop.uuid.bytes[1],
+ (unsigned char)prop.uuid.bytes[2],
+ (unsigned char)prop.uuid.bytes[3],
+ (unsigned char)prop.uuid.bytes[4],
+ (unsigned char)prop.uuid.bytes[5],
+ (unsigned char)prop.uuid.bytes[6],
+ (unsigned char)prop.uuid.bytes[7],
+ (unsigned char)prop.uuid.bytes[8],
+ (unsigned char)prop.uuid.bytes[9],
+ (unsigned char)prop.uuid.bytes[10],
+ (unsigned char)prop.uuid.bytes[11],
+ (unsigned char)prop.uuid.bytes[12],
+ (unsigned char)prop.uuid.bytes[13],
+ (unsigned char)prop.uuid.bytes[14],
+ (unsigned char)prop.uuid.bytes[15]
+ );
+ dev_ctx->uuid = uuid;
+ #else
+ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1b56f858..ee4f2dcb 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
+ props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {

View File

@@ -124,10 +124,6 @@ type DeviceMemory struct {
// may not be persistent across instances of the runner.
Name string
// UUID is a unique persistent identifier for the device for matching
// with system management libraries
UUID string
// Weights is the per-layer memory needed for the model weights.
Weights []Memory
@@ -156,10 +152,6 @@ func (m DeviceMemory) LogValue() slog.Value {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}
if len(attrs) > 0 && m.UUID != "" {
attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...)
}
return slog.GroupValue(attrs...)
}

View File

@@ -136,9 +136,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
requiredMemory.CPU.UUID = C.GoString(props.uuid)
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
@@ -153,9 +150,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
})
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(d, &props)
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
}

View File

@@ -152,7 +152,6 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;

View File

@@ -2884,7 +2884,6 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string uuid;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2897,11 +2896,6 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
return ctx->description.c_str();
}
static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->uuid.c_str();
}
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_set_device(ctx->device);
@@ -2916,7 +2910,6 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -3465,32 +3458,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
#if !defined(GGML_USE_HIP)
char uuid[64];
snprintf(uuid, sizeof(uuid),
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
(unsigned char)prop.uuid.bytes[0],
(unsigned char)prop.uuid.bytes[1],
(unsigned char)prop.uuid.bytes[2],
(unsigned char)prop.uuid.bytes[3],
(unsigned char)prop.uuid.bytes[4],
(unsigned char)prop.uuid.bytes[5],
(unsigned char)prop.uuid.bytes[6],
(unsigned char)prop.uuid.bytes[7],
(unsigned char)prop.uuid.bytes[8],
(unsigned char)prop.uuid.bytes[9],
(unsigned char)prop.uuid.bytes[10],
(unsigned char)prop.uuid.bytes[11],
(unsigned char)prop.uuid.bytes[12],
(unsigned char)prop.uuid.bytes[13],
(unsigned char)prop.uuid.bytes[14],
(unsigned char)prop.uuid.bytes[15]
);
dev_ctx->uuid = uuid;
#else
dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
#endif
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,

View File

@@ -5703,7 +5703,6 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {

View File

@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL {
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
v.special = append(v.special, v.Values[i])
}
}

16
model/vocabulary_test.go Normal file
View File

@@ -0,0 +1,16 @@
package model
import "testing"
func TestVocabulary_SpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
}
specialVocab := vocab.SpecialVocabulary()
if len(specialVocab) != 4 {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}

View File

@@ -1,115 +0,0 @@
package cache
import (
"fmt"
"log/slog"
"os"
"slices"
"sync"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
)
// cacheEntry stores capabilities and the modification time of the model file
type cacheEntry struct {
capabilities []model.Capability
modTime time.Time
}
// ggufCapabilities is a cache for gguf model capabilities
var ggufCapabilities = &sync.Map{}
// ModelInfo contains the minimal information needed to determine capabilities
type ModelInfo struct {
ModelPath string
ProjectorPaths []string
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func Capabilities(info ModelInfo) []model.Capability {
capabilities, err := ggufCapabilties(info.ModelPath)
if err != nil {
slog.Error("could not determine gguf capabilities", "error", err)
}
if info.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(info.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(info.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(info.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(info.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
// Get file info to check modification time
fileInfo, err := os.Stat(modelPath)
if err != nil {
return nil, err
}
currentModTime := fileInfo.ModTime()
// Check if we have a cached entry
if cached, ok := ggufCapabilities.Load(modelPath); ok {
entry := cached.(cacheEntry)
// If the file hasn't been modified since we cached it, return the cached capabilities
if entry.modTime.Equal(currentModTime) {
return entry.capabilities, nil
}
}
// If not cached or file was modified, read the model file to determine capabilities
capabilities := []model.Capability{}
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
f, err := ggml.Decode(r, 1024)
if err != nil {
return nil, err
}
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
// Cache the capabilities with the modification time
ggufCapabilities.Store(modelPath, cacheEntry{
capabilities: capabilities,
modTime: currentModTime,
})
return capabilities, nil
}

View File

@@ -1,211 +0,0 @@
package cache
import (
"bytes"
"maps"
"os"
"slices"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
func testGGUF(tb testing.TB, customKV ggml.KV) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{}
maps.Copy(kv, customKV)
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{1, 1},
WriterTo: bytes.NewBuffer(make([]byte, 4)),
},
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestCapabilities(t *testing.T) {
ggufCapabilities.Range(func(key, value any) bool {
ggufCapabilities.Delete(key)
return true
})
// Create test model paths
completionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
})
visionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
})
embeddingModelPath := testGGUF(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
})
// Create templates
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testCases := []struct {
name string
model ModelInfo
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability from gguf",
model: ModelInfo{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision capability from projector",
model: ModelInfo{
ModelPath: completionModelPath,
ProjectorPaths: []string{"/path/to/projector"},
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: ModelInfo{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: ModelInfo{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// First call - should read from file
caps := Capabilities(tc.model)
slices.Sort(caps)
slices.Sort(tc.expectedCaps)
if !slices.Equal(caps, tc.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
}
// Verify caching for models that read from GGUF
if tc.model.ModelPath != "" {
// Check that entry is cached
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
if !ok {
t.Error("Expected capabilities to be cached")
}
// Second call - should use cache
caps2 := Capabilities(tc.model)
slices.Sort(caps2)
if !slices.Equal(caps, caps2) {
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
}
}
})
}
// Test cache invalidation on file modification
t.Run("cache invalidation", func(t *testing.T) {
// Use completion model for this test
info := ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
}
// Get initial cached entry
cached, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Fatal("Expected model to be cached from previous tests")
}
entry := cached.(cacheEntry)
// Modify the file's timestamp to the future
future := time.Now().Add(time.Hour)
err := os.Chtimes(completionModelPath, future, future)
if err != nil {
t.Fatalf("Failed to update file timestamp: %v", err)
}
// Call should re-read from file due to changed modtime
caps := Capabilities(info)
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
}
// Check that cache was updated with new modtime
cached2, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Error("Expected capabilities to be cached after re-read")
}
entry2 := cached2.(cacheEntry)
if entry2.modTime.Equal(entry.modTime) {
t.Error("Expected cache entry to have updated modTime")
}
})
}

View File

@@ -23,9 +23,10 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -67,14 +68,64 @@ type Model struct {
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
r, err := os.Open(m.ModelPath)
if err == nil {
defer r.Close()
f, err := ggml.Decode(r, 1024)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
} else {
slog.Error("couldn't decode ggml", "error", err)
}
} else {
slog.Error("couldn't open model file", "error", err)
}
if m.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for vision capability in projector-based models
if len(m.ProjectorPaths) > 0 {
capabilities = append(capabilities, model.CapabilityVision)
}
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(want ...model.Capability) error {
available := cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
ProjectorPaths: m.ProjectorPaths,
Template: m.Template,
})
available := m.Capabilities()
var errs []error
// Map capabilities to their corresponding error

View File

@@ -1,42 +1,252 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestModelCheckCapabilities(t *testing.T) {
// Create simple model file for tests that don't depend on GGUF content
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
// Constants for GGUF magic bytes and version
var (
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
ggufVer = uint32(3) // Version 3
)
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
var buf bytes.Buffer
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
// Write GGUF header
buf.Write(ggufMagic)
binary.Write(&buf, binary.LittleEndian, ggufVer)
// Write tensor count (0 for our test)
var numTensors uint64 = 0
binary.Write(&buf, binary.LittleEndian, numTensors)
// Calculate number of metadata entries
numMetaEntries := uint64(1) // architecture entry
if vision {
numMetaEntries++
}
// Add embedding entry if architecture is "bert"
if architecture == "bert" {
numMetaEntries++
}
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
// Write architecture metadata
archKey := "general.architecture"
keyLen := uint64(len(archKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(archKey)
// String type (8)
var strType uint32 = 8
binary.Write(&buf, binary.LittleEndian, strType)
// String length
strLen := uint64(len(architecture))
binary.Write(&buf, binary.LittleEndian, strLen)
buf.WriteString(architecture)
if vision {
visionKey := architecture + ".vision.block_count"
keyLen = uint64(len(visionKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(visionKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var countVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, countVal)
}
// Write embedding metadata if architecture is "bert"
if architecture == "bert" {
poolKey := architecture + ".pooling_type"
keyLen = uint64(len(poolKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(poolKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var poolingVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, poolingVal)
}
return buf.Bytes()
}
func TestModelCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
// Create different types of mock model files
completionModelPath := filepath.Join(tempDir, "model.bin")
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
// Create a simple model file for tests that don't depend on GGUF content
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
if err := errors.Join(
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools and insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
func TestModelCheckCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir := t.TempDir()
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
simpleModelPath := filepath.Join(tempDir, "model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
if err := errors.Join(
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
); err != nil {
t.Fatalf("Failed to create model files: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
@@ -51,7 +261,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "completion model without tools capability",
model: Model{
ModelPath: completionModelPath,
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools},
@@ -60,7 +270,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model with all needed capabilities",
model: Model{
ModelPath: completionModelPath,
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
@@ -68,7 +278,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model missing insert capability",
model: Model{
ModelPath: completionModelPath,
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityInsert},
@@ -77,7 +287,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "model missing vision capability",
model: Model{
ModelPath: completionModelPath,
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
@@ -102,7 +312,7 @@ func TestModelCheckCapabilities(t *testing.T) {
{
name: "unknown capability",
model: Model{
ModelPath: completionModelPath,
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{"unknown"},

View File

@@ -59,7 +59,7 @@ type DiskCache struct {
testHookBeforeFinalWrite func(f *os.File)
}
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}

View File

@@ -257,8 +257,16 @@ func TestQuantizeModel(t *testing.T) {
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p, _ := createBinFile(t, tt.kv, tt.tensors)
fp, err := os.Open(p)
f, err := os.CreateTemp(t.TempDir(), tt.name)
if err != nil {
t.Fatal(err.Error())
}
defer f.Close()
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
if err != nil {
t.Fatalf("failed to create initial model: %s", err)
}
fp, err := os.Open(f.Name())
if err != nil {
t.Fatal(err.Error())
}

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/cache"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -820,17 +819,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: cache.Capabilities(cache.ModelInfo{
ModelPath: m.ModelPath,
Template: m.Template,
ProjectorPaths: m.ProjectorPaths,
}),
ModifiedAt: manifest.fi.ModTime(),
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(),
}
var params []string

View File

@@ -112,7 +112,11 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
p, _ := createBinFile(t, ggml.KV{
f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err)
defer f.Close()
require.NoError(t, ggml.WriteGGUF(f, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
@@ -125,14 +129,14 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
}, []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
})
}))
require.NoError(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
b.f, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
model := &Model{Name: modelName, ModelPath: p}
f, err := llm.LoadModel(model.ModelPath, 0)
if err != nil {
t.Fatal(err)
}
b.f = f
if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond}
}

View File

@@ -18,9 +18,8 @@ const (
)
type Parser struct {
tag string
names []string
properties []string
tag string
tools []api.Tool
state toolsState
buffer []byte
@@ -34,15 +33,10 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
}
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
var p Parser
for _, t := range tools {
p.names = append(p.names, t.Function.Name)
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
}
return &Parser{
tag: tag,
tools: tools,
}
p.tag = tag
return &p
}
// Add processes a string input to parse tool calls and content that
@@ -121,36 +115,40 @@ func (p *Parser) findTag() (int, bool) {
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
var name string
var args map[string]any
var tool *api.Tool
var end int = len(p.buffer)
var i int
// find tool name
var i int
for _, n := range p.names {
for _, t := range p.tools {
n := t.Function.Name
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
name = n
tool = &t
end = i + len(n)
}
}
}
if name == "" {
if tool == nil {
return nil
}
if args, i = p.findArguments(); args == nil {
return nil
}
// only look for arguments if the tool has parameters
args := map[string]any{}
if len(tool.Function.Parameters.Properties) > 0 {
if args, i = p.findArguments(*tool); args == nil {
return nil
}
if i > end {
end = i
if i > end {
end = i
}
}
tc := &api.ToolCall{
Function: api.ToolCallFunction{
Name: name,
Name: tool.Function.Name,
Arguments: args,
Index: p.n,
},
@@ -162,13 +160,17 @@ func (p *Parser) parseToolCall() *api.ToolCall {
}
// findArguments returns the first object that appears to be
// arguments and the position where the arguments end, returning nil and 0 if
// an invalid JSON object or non-arguments object is found first
func (p *Parser) findArguments() (map[string]any, int) {
// arguments for the provided tool, returning nil
func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
if len(p.buffer) == 0 {
return nil, 0
}
// no arguments to parse
if len(tool.Function.Parameters.Properties) == 0 {
return nil, 0
}
var braces int
var start int = -1
var end int
@@ -184,11 +186,13 @@ func (p *Parser) findArguments() (map[string]any, int) {
}
if c == '}' {
braces--
if braces == 0 && start != -1 {
end = i + 1
object = p.buffer[start:end]
break
if start != -1 {
braces--
if braces == 0 {
end = i + 1
object = p.buffer[start:end]
break
}
}
}
}
@@ -206,24 +210,27 @@ func (p *Parser) findArguments() (map[string]any, int) {
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch v := obj.(type) {
switch obj := obj.(type) {
case map[string]any:
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
found := true
for key := range obj {
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
found = false
break
}
}
for _, value := range v {
if found {
return obj
}
for _, value := range obj {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range v {
for _, item := range obj {
if result := find(item); result != nil {
return result
}

View File

@@ -104,6 +104,13 @@ func TestParser(t *testing.T) {
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello",
},
},
}
tests := []struct {
@@ -144,6 +151,20 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "invalid arguments",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "missing args",
inputs: []string{`<tool_call>{"name": "get_conditions"}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
@@ -161,6 +182,28 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "qwen no args tool call",
inputs: []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
content: "Let me say hello to the user. I'll use the say_hello tool ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "qwen no args with text",
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
content: "Let me say hello to the user. I'll use the say_hello tool. ",
tmpl: qwen,
calls: nil,
},
{
name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
@@ -189,7 +232,7 @@ func TestParser(t *testing.T) {
},
},
{
name: "two tool calls",
name: "qwen two tool calls",
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
content: "Okay, let's call both tools! ",
tmpl: qwen,
@@ -215,6 +258,30 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "qwen two tool calls one with no args",
inputs: []string{`Let me check the weather. <tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
content: "Let me check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "deepseek",
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
@@ -338,6 +405,52 @@ func TestParser(t *testing.T) {
content: "for { fmt.Println(\"hello\") }",
tmpl: json,
},
{
name: "json no args tool call",
inputs: []string{
"{\"name\": \"say_hello\"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "json no args no tool call",
inputs: []string{
"I'll use the say_hello tool to say hello to the user.",
},
content: "I'll use the say_hello tool to say hello to the user.",
tmpl: json,
calls: nil,
},
// TODO (jmorganca): this is a false positive, we should
// not be parsing this as a tool call
{
name: "json no args false positive",
inputs: []string{
`{say_hello!!!}`,
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "list multiple",
inputs: []string{
@@ -380,6 +493,30 @@ func TestParser(t *testing.T) {
},
{
name: "list partial",
inputs: []string{
"[{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list invalid",
inputs: []string{
"[",
"{",
@@ -393,6 +530,33 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "list trailing ]",
inputs: []string{
"[",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
"]",
"]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list not a tool call",
inputs: []string{
@@ -404,6 +568,26 @@ func TestParser(t *testing.T) {
tmpl: list,
calls: nil,
},
{
name: "list with no arguments",
inputs: []string{
"[",
"{",
"\"name\": \"say_hello\"",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
}
for _, tt := range tests {
@@ -700,25 +884,75 @@ func TestFindTag(t *testing.T) {
}
func TestFindArguments(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the temperature for",
},
},
},
},
}
tool2 := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello to the user",
},
}
tests := []struct {
name string
buffer []byte
want map[string]any
tool api.Tool
}{
{
name: "empty string",
buffer: []byte{},
want: nil,
tool: tool,
},
{
name: "whitespace only",
buffer: []byte(" \n\t "),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - missing closing",
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - extra closing",
@@ -726,11 +960,13 @@ func TestFindArguments(t *testing.T) {
want: map[string]any{
"format": "fahrenheit",
},
tool: tool,
},
{
name: "invalid JSON",
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
want: nil,
tool: tool,
},
{
name: "valid json",
@@ -739,6 +975,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments with special tokens",
@@ -747,6 +984,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments in array",
@@ -755,6 +993,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "nested deep",
@@ -763,39 +1002,49 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "one arg",
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`),
want: map[string]any{
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "two args",
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
want: map[string]any{
"location": "San Francisco, CA",
"format": "fahrenheit",
},
tool: tool,
},
{
name: "no args",
buffer: []byte(`{"name": "say_hello"}`),
want: nil,
tool: tool2,
},
{
name: "deepseek",
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{
"location": "Tokyo",
},
tool: tool,
},
}
for _, tt := range tests {
parser := &Parser{
buffer: tt.buffer,
properties: []string{"format", "location"},
buffer: tt.buffer,
tools: []api.Tool{tool, tool2},
}
t.Run(tt.name, func(t *testing.T) {
got, _ := parser.findArguments()
got, _ := parser.findArguments(tool)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)