mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
8 Commits
brucemacd/
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2a4d058f9 | ||
|
|
63e7634014 | ||
|
|
8d51d92f3b | ||
|
|
2348fef568 | ||
|
|
883f655dd6 | ||
|
|
a6fbfc880c | ||
|
|
502028968d | ||
|
|
5a8eb0e151 |
@@ -407,6 +407,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
||||
@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := slices.Collect(maps.Keys(kv))
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
return -1
|
||||
} else {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var s uint64
|
||||
|
||||
@@ -2,62 +2,82 @@ package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 608,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
@@ -726,68 +725,10 @@ type CompletionResponse struct {
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
// unicodeBufferHandler wraps a completion response callback to handle partial UTF-8 sequences.
|
||||
// This function creates a stateful closure that is NOT safe for concurrent use.
|
||||
// Each completion request should create its own handler instance.
|
||||
func unicodeBufferHandler(fn func(CompletionResponse)) func(CompletionResponse) {
|
||||
var pendingUTF8 string
|
||||
|
||||
return func(resp CompletionResponse) {
|
||||
if resp.Content == "" && !resp.Done {
|
||||
// No content to process, just pass through
|
||||
fn(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Combine any pending UTF-8 with current content
|
||||
combinedContent := pendingUTF8 + resp.Content
|
||||
pendingUTF8 = ""
|
||||
|
||||
// Check if combined content is valid UTF-8
|
||||
if utf8.ValidString(combinedContent) {
|
||||
// Valid UTF-8, send it
|
||||
resp.Content = combinedContent
|
||||
fn(resp)
|
||||
} else {
|
||||
// Invalid UTF-8
|
||||
if resp.Done {
|
||||
// This is the final response, trim incomplete UTF-8
|
||||
trimmedContent := combinedContent
|
||||
for !utf8.ValidString(trimmedContent) && len(trimmedContent) > 0 {
|
||||
trimmedContent = trimmedContent[:len(trimmedContent)-1]
|
||||
}
|
||||
resp.Content = trimmedContent
|
||||
fn(resp)
|
||||
} else {
|
||||
// Not final response, split valid and invalid parts
|
||||
validPrefix := combinedContent
|
||||
for !utf8.ValidString(validPrefix) && len(validPrefix) > 0 {
|
||||
validPrefix = validPrefix[:len(validPrefix)-1]
|
||||
}
|
||||
|
||||
if len(validPrefix) > 0 {
|
||||
// Send valid prefix
|
||||
resp.Content = validPrefix
|
||||
fn(resp)
|
||||
// Buffer the remainder
|
||||
pendingUTF8 = combinedContent[len(validPrefix):]
|
||||
} else {
|
||||
// No valid prefix, buffer everything
|
||||
pendingUTF8 = combinedContent
|
||||
// Don't send this response
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
|
||||
slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt)
|
||||
|
||||
// Wrap the callback with unicode buffer handling
|
||||
unicodeFn := unicodeBufferHandler(fn)
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
@@ -913,13 +854,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
unicodeFn(CompletionResponse{
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Done {
|
||||
unicodeFn(c)
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,152 +70,3 @@ func TestLLMServerCompletionFormat(t *testing.T) {
|
||||
}, nil)
|
||||
checkValid(err)
|
||||
}
|
||||
|
||||
func TestUnicodeBufferHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputResponses []CompletionResponse
|
||||
expectedResponses []CompletionResponse
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "complete_unicode",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "!", Done: true},
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "!", Done: true},
|
||||
},
|
||||
description: "All responses with valid unicode should pass through unchanged",
|
||||
},
|
||||
{
|
||||
name: "incomplete_unicode_at_end_with_done",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: string([]byte{0xF0, 0x9F}), Done: true}, // Incomplete emoji with Done=true
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: "", Done: true}, // Content is trimmed but response is still sent with Done=true
|
||||
},
|
||||
description: "When Done=true, incomplete Unicode at the end should be trimmed",
|
||||
},
|
||||
{
|
||||
name: "split_unicode_across_responses",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Hello " + string([]byte{0xF0, 0x9F}), Done: false}, // First part of 😀
|
||||
{Content: string([]byte{0x98, 0x80}) + " world!", Done: true}, // Second part of 😀 and more text
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Hello ", Done: false}, // Incomplete Unicode trimmed
|
||||
{Content: "😀 world!", Done: true}, // Complete emoji in second response
|
||||
},
|
||||
description: "Unicode split across responses should be handled correctly",
|
||||
},
|
||||
{
|
||||
name: "incomplete_unicode_buffered",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Test " + string([]byte{0xF0, 0x9F}), Done: false}, // Incomplete emoji
|
||||
{Content: string([]byte{0x98, 0x80}), Done: false}, // Complete the emoji
|
||||
{Content: " done", Done: true},
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Test ", Done: false}, // First part without incomplete unicode
|
||||
{Content: "😀", Done: false}, // Complete emoji
|
||||
{Content: " done", Done: true},
|
||||
},
|
||||
description: "Incomplete unicode should be buffered and combined with next response",
|
||||
},
|
||||
{
|
||||
name: "empty_response_with_done",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Complete response", Done: false},
|
||||
{Content: "", Done: true}, // Empty response with Done=true
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Complete response", Done: false},
|
||||
{Content: "", Done: true}, // Should still be sent because Done=true
|
||||
},
|
||||
description: "Empty final response with Done=true should still be sent",
|
||||
},
|
||||
{
|
||||
name: "done_reason_preserved",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: "Response", Done: false},
|
||||
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
{Content: "Response", Done: false},
|
||||
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
|
||||
},
|
||||
description: "DoneReason should be preserved in the final response",
|
||||
},
|
||||
{
|
||||
name: "only_incomplete_unicode_not_done",
|
||||
inputResponses: []CompletionResponse{
|
||||
{Content: string([]byte{0xF0, 0x9F}), Done: false}, // Only incomplete unicode
|
||||
},
|
||||
expectedResponses: []CompletionResponse{
|
||||
// No response expected - should be buffered
|
||||
},
|
||||
description: "Response with only incomplete unicode should be buffered if not done",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var actualResponses []CompletionResponse
|
||||
|
||||
// Create a callback that collects responses
|
||||
callback := func(resp CompletionResponse) {
|
||||
actualResponses = append(actualResponses, resp)
|
||||
}
|
||||
|
||||
// Create the unicode buffer handler
|
||||
handler := unicodeBufferHandler(callback)
|
||||
|
||||
// Send all input responses through the handler
|
||||
for _, resp := range tt.inputResponses {
|
||||
handler(resp)
|
||||
}
|
||||
|
||||
// Verify the number of responses
|
||||
if len(actualResponses) != len(tt.expectedResponses) {
|
||||
t.Fatalf("%s: got %d responses, want %d responses",
|
||||
tt.description, len(actualResponses), len(tt.expectedResponses))
|
||||
}
|
||||
|
||||
// Verify each response matches the expected one
|
||||
for i, expected := range tt.expectedResponses {
|
||||
if i >= len(actualResponses) {
|
||||
t.Fatalf("%s: missing response at index %d", tt.description, i)
|
||||
continue
|
||||
}
|
||||
|
||||
actual := actualResponses[i]
|
||||
|
||||
// Verify content
|
||||
if actual.Content != expected.Content {
|
||||
t.Errorf("%s: response[%d].Content = %q, want %q",
|
||||
tt.description, i, actual.Content, expected.Content)
|
||||
}
|
||||
|
||||
// Verify Done flag
|
||||
if actual.Done != expected.Done {
|
||||
t.Errorf("%s: response[%d].Done = %v, want %v",
|
||||
tt.description, i, actual.Done, expected.Done)
|
||||
}
|
||||
|
||||
// Verify DoneReason if specified
|
||||
if actual.DoneReason != expected.DoneReason {
|
||||
t.Errorf("%s: response[%d].DoneReason = %v, want %v",
|
||||
tt.description, i, actual.DoneReason, expected.DoneReason)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func FindStop(sequence string, stops []string) (bool, string) {
|
||||
@@ -31,41 +29,68 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
|
||||
// truncateStop removes the provided stop string from pieces,
|
||||
// returning the partial pieces with stop removed, including truncating
|
||||
// the last piece if required (and signalling if this was the case)
|
||||
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
|
||||
var sequence string
|
||||
for _, resp := range resps {
|
||||
sequence += resp.Content
|
||||
func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
joined := strings.Join(pieces, "")
|
||||
|
||||
index := strings.Index(joined, stop)
|
||||
if index == -1 {
|
||||
return pieces, false
|
||||
}
|
||||
|
||||
idx := strings.Index(sequence, stop)
|
||||
if idx < 0 {
|
||||
return resps, false
|
||||
joined = joined[:index]
|
||||
|
||||
// Split truncated string back into pieces of original lengths
|
||||
lengths := make([]int, len(pieces))
|
||||
for i, piece := range pieces {
|
||||
lengths[i] = len(piece)
|
||||
}
|
||||
|
||||
truncated := sequence[:idx]
|
||||
if len(truncated) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
result := make([]llm.CompletionResponse, 0, len(resps))
|
||||
|
||||
// Track position in truncated sequence
|
||||
pos := 0
|
||||
truncationHappened := false
|
||||
for _, resp := range resps {
|
||||
if pos >= len(truncated) {
|
||||
var result []string
|
||||
tokenTruncated := false
|
||||
start := 0
|
||||
for _, length := range lengths {
|
||||
if start >= len(joined) {
|
||||
break
|
||||
}
|
||||
|
||||
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
|
||||
if len(chunk) < len(resp.Content) {
|
||||
truncationHappened = true
|
||||
end := start + length
|
||||
if end > len(joined) {
|
||||
end = len(joined)
|
||||
tokenTruncated = true
|
||||
}
|
||||
if len(chunk) > 0 {
|
||||
result = append(result, llm.CompletionResponse{Content: chunk})
|
||||
}
|
||||
pos += len(resp.Content)
|
||||
result = append(result, joined[start:end])
|
||||
start = end
|
||||
}
|
||||
|
||||
return result, truncationHappened
|
||||
return result, tokenTruncated
|
||||
}
|
||||
|
||||
func IncompleteUnicode(token string) bool {
|
||||
incomplete := false
|
||||
|
||||
// check if there is incomplete UTF-8 character at the end
|
||||
for i := 1; i < 5 && i <= len(token); i++ {
|
||||
c := token[len(token)-i]
|
||||
|
||||
if (c & 0xc0) == 0x80 {
|
||||
// continuation byte: 10xxxxxx
|
||||
continue
|
||||
}
|
||||
|
||||
if (c & 0xe0) == 0xc0 {
|
||||
// 2-byte character: 110xxxxx ...
|
||||
incomplete = i < 2
|
||||
} else if (c & 0xf0) == 0xe0 {
|
||||
// 3-byte character: 1110xxxx ...
|
||||
incomplete = i < 3
|
||||
} else if (c & 0xf8) == 0xf0 {
|
||||
// 4-byte character: 11110xxx ...
|
||||
incomplete = i < 4
|
||||
}
|
||||
|
||||
// else 1-byte character or invalid byte
|
||||
break
|
||||
}
|
||||
|
||||
return incomplete
|
||||
}
|
||||
|
||||
@@ -1,84 +1,51 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestTruncateStop(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pieces []llm.CompletionResponse
|
||||
pieces []string
|
||||
stop string
|
||||
expected []llm.CompletionResponse
|
||||
expected []string
|
||||
expectedTrunc bool
|
||||
}{
|
||||
{
|
||||
name: "Single word",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: "world"},
|
||||
},
|
||||
stop: "world",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
},
|
||||
name: "Single word",
|
||||
pieces: []string{"hello", "world"},
|
||||
stop: "world",
|
||||
expected: []string{"hello"},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Partial",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " wor"},
|
||||
},
|
||||
stop: "or",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " w"},
|
||||
},
|
||||
name: "Partial",
|
||||
pieces: []string{"hello", "wor"},
|
||||
stop: "or",
|
||||
expected: []string{"hello", "w"},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Suffix",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " there"},
|
||||
{Content: "!"},
|
||||
},
|
||||
stop: "!",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " there"},
|
||||
},
|
||||
name: "Suffix",
|
||||
pieces: []string{"Hello", " there", "!"},
|
||||
stop: "!",
|
||||
expected: []string{"Hello", " there"},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Suffix partial",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " the"},
|
||||
{Content: "re!"},
|
||||
},
|
||||
stop: "there!",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " "},
|
||||
},
|
||||
name: "Suffix partial",
|
||||
pieces: []string{"Hello", " the", "re!"},
|
||||
stop: "there!",
|
||||
expected: []string{"Hello", " "},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Middle",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " wo"},
|
||||
},
|
||||
stop: "llo w",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "He"},
|
||||
},
|
||||
name: "Middle",
|
||||
pieces: []string{"hello", " wor"},
|
||||
stop: "llo w",
|
||||
expected: []string{"he"},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
}
|
||||
@@ -87,23 +54,76 @@ func TestTruncateStop(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
|
||||
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
|
||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func formatContentDiff(result, expected []llm.CompletionResponse) string {
|
||||
var s string
|
||||
for i := 0; i < len(result) || i < len(expected); i++ {
|
||||
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
|
||||
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
|
||||
} else if i < len(result) && i >= len(expected) {
|
||||
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
|
||||
} else if i >= len(result) && i < len(expected) {
|
||||
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
|
||||
}
|
||||
func TestIncompleteUnicode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Basic",
|
||||
input: "hi",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Two byte",
|
||||
input: "hi" + string([]byte{0xc2, 0xa3}),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Two byte - missing last",
|
||||
input: "hi" + string([]byte{0xc2}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Three byte",
|
||||
input: "hi" + string([]byte{0xe0, 0xA0, 0x80}),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Three byte - missing last",
|
||||
input: "hi" + string([]byte{0xe0, 0xA0}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Three byte - missing last 2",
|
||||
input: "hi" + string([]byte{0xe0}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Four byte",
|
||||
input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Four byte - missing last",
|
||||
input: "hi" + string([]byte{0xf0, 0x92, 0x8a}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Four byte - missing last 2",
|
||||
input: "hi" + string([]byte{0xf0, 0x92}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Four byte - missing last 3",
|
||||
input: "hi" + string([]byte{0xf0}),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IncompleteUnicode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
@@ -51,13 +52,13 @@ type Sequence struct {
|
||||
pendingInputs []input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []llm.CompletionResponse
|
||||
pendingResponses []string
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
// channel to send responses over
|
||||
responses chan llm.CompletionResponse
|
||||
responses chan string
|
||||
|
||||
// channel to stop decoding (such as if the remote connection is closed)
|
||||
quit chan bool
|
||||
@@ -88,19 +89,6 @@ type Sequence struct {
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
|
||||
if len(resp.Content) > 0 || resp.Done {
|
||||
select {
|
||||
case seq.responses <- resp:
|
||||
// Successfully sent
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
@@ -159,8 +147,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||
responses: make(chan llm.CompletionResponse, 100),
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
@@ -284,15 +272,36 @@ func (s *Server) allNil() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
|
||||
// Check if there are any partial UTF-8 characters remaining.
|
||||
// We already check and queue as we are generating but some may
|
||||
// still make it here:
|
||||
// - Sequence is ending, e.g. generation limit has been hit
|
||||
// - Invalid characters in the middle of a string
|
||||
// This is a stricter check to ensure we never output invalid Unicode.
|
||||
for !utf8.ValidString(joined) {
|
||||
joined = joined[:len(joined)-1]
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case seq.responses <- joined:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
// Send any remaining pending responses
|
||||
for _, resp := range seq.pendingResponses {
|
||||
seq.send(resp)
|
||||
}
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
@@ -481,11 +490,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
seq.inputs = []input{{token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||
sequence := ""
|
||||
for _, r := range seq.pendingResponses {
|
||||
sequence += r.Content
|
||||
}
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
@@ -517,13 +523,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, resp := range seq.pendingResponses {
|
||||
if !seq.send(resp) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
break
|
||||
}
|
||||
if common.IncompleteUnicode(sequence) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -621,7 +627,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/image/bmp"
|
||||
"golang.org/x/sync/semaphore"
|
||||
@@ -55,13 +56,13 @@ type Sequence struct {
|
||||
pendingInputs []input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []llm.CompletionResponse
|
||||
pendingResponses []string
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
// channel to send responses over
|
||||
responses chan llm.CompletionResponse
|
||||
responses chan string
|
||||
|
||||
// channel to stop decoding (such as if the remote connection is closed)
|
||||
quit chan bool
|
||||
@@ -93,19 +94,6 @@ type Sequence struct {
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
|
||||
if len(resp.Content) > 0 || resp.Done {
|
||||
select {
|
||||
case seq.responses <- resp:
|
||||
// Successfully sent
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
@@ -179,8 +167,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||
responses: make(chan llm.CompletionResponse, 100),
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
@@ -325,15 +313,36 @@ func (s *Server) allNil() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
|
||||
// Check if there are any partial UTF-8 characters remaining.
|
||||
// We already check and queue as we are generating but some may
|
||||
// still make it here:
|
||||
// - Sequence is ending, e.g. generation limit has been hit
|
||||
// - Invalid characters in the middle of a string
|
||||
// This is a stricter check to ensure we never output invalid Unicode.
|
||||
for !utf8.ValidString(joined) {
|
||||
joined = joined[:len(joined)-1]
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case seq.responses <- joined:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
// Send any remaining pending responses
|
||||
for _, resp := range seq.pendingResponses {
|
||||
seq.send(resp)
|
||||
}
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
@@ -532,11 +541,8 @@ func (s *Server) processBatch() error {
|
||||
|
||||
seq.inputs = []input.Input{{Token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||
sequence := ""
|
||||
for _, r := range seq.pendingResponses {
|
||||
sequence += r.Content
|
||||
}
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
@@ -568,14 +574,13 @@ func (s *Server) processBatch() error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send all pending responses directly without unicode checking
|
||||
for _, resp := range seq.pendingResponses {
|
||||
if !seq.send(resp) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
break
|
||||
}
|
||||
if common.IncompleteUnicode(sequence) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -678,7 +683,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
|
||||
115
server/cache/capabilities.go
vendored
Normal file
115
server/cache/capabilities.go
vendored
Normal file
@@ -0,0 +1,115 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// cacheEntry stores capabilities and the modification time of the model file
|
||||
type cacheEntry struct {
|
||||
capabilities []model.Capability
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
// ggufCapabilities is a cache for gguf model capabilities
|
||||
var ggufCapabilities = &sync.Map{}
|
||||
|
||||
// ModelInfo contains the minimal information needed to determine capabilities
|
||||
type ModelInfo struct {
|
||||
ModelPath string
|
||||
ProjectorPaths []string
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func Capabilities(info ModelInfo) []model.Capability {
|
||||
capabilities, err := ggufCapabilties(info.ModelPath)
|
||||
if err != nil {
|
||||
slog.Error("could not determine gguf capabilities", "error", err)
|
||||
}
|
||||
|
||||
if info.Template == nil {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for tools capability
|
||||
if slices.Contains(info.Template.Vars(), "tools") {
|
||||
capabilities = append(capabilities, model.CapabilityTools)
|
||||
}
|
||||
|
||||
// Check for insert capability
|
||||
if slices.Contains(info.Template.Vars(), "suffix") {
|
||||
capabilities = append(capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
// Check for vision capability in projector-based models
|
||||
if len(info.ProjectorPaths) > 0 {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := thinking.InferTags(info.Template.Template)
|
||||
if openingTag != "" && closingTag != "" {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
|
||||
// Get file info to check modification time
|
||||
fileInfo, err := os.Stat(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentModTime := fileInfo.ModTime()
|
||||
|
||||
// Check if we have a cached entry
|
||||
if cached, ok := ggufCapabilities.Load(modelPath); ok {
|
||||
entry := cached.(cacheEntry)
|
||||
// If the file hasn't been modified since we cached it, return the cached capabilities
|
||||
if entry.modTime.Equal(currentModTime) {
|
||||
return entry.capabilities, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not cached or file was modified, read the model file to determine capabilities
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
r, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
f, err := ggml.Decode(r, 1024)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Cache the capabilities with the modification time
|
||||
ggufCapabilities.Store(modelPath, cacheEntry{
|
||||
capabilities: capabilities,
|
||||
modTime: currentModTime,
|
||||
})
|
||||
|
||||
return capabilities, nil
|
||||
}
|
||||
211
server/cache/capabilities_test.go
vendored
Normal file
211
server/cache/capabilities_test.go
vendored
Normal file
@@ -0,0 +1,211 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
|
||||
func testGGUF(tb testing.TB, customKV ggml.KV) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
kv := ggml.KV{}
|
||||
maps.Copy(kv, customKV)
|
||||
|
||||
tensors := []*ggml.Tensor{
|
||||
{
|
||||
Name: "token_embd.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{1, 1},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4)),
|
||||
},
|
||||
}
|
||||
|
||||
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func TestCapabilities(t *testing.T) {
|
||||
ggufCapabilities.Range(func(key, value any) bool {
|
||||
ggufCapabilities.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Create test model paths
|
||||
completionModelPath := testGGUF(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
})
|
||||
|
||||
visionModelPath := testGGUF(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
})
|
||||
|
||||
embeddingModelPath := testGGUF(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
})
|
||||
|
||||
// Create templates
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
model ModelInfo
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: ModelInfo{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||
},
|
||||
{
|
||||
name: "model with completion, tools, and insert capability",
|
||||
model: ModelInfo{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: ModelInfo{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability from gguf",
|
||||
model: ModelInfo{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability from projector",
|
||||
model: ModelInfo{
|
||||
ModelPath: completionModelPath,
|
||||
ProjectorPaths: []string{"/path/to/projector"},
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision, tools, and insert capability",
|
||||
model: ModelInfo{
|
||||
ModelPath: visionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: ModelInfo{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// First call - should read from file
|
||||
caps := Capabilities(tc.model)
|
||||
slices.Sort(caps)
|
||||
slices.Sort(tc.expectedCaps)
|
||||
if !slices.Equal(caps, tc.expectedCaps) {
|
||||
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
|
||||
}
|
||||
|
||||
// Verify caching for models that read from GGUF
|
||||
if tc.model.ModelPath != "" {
|
||||
// Check that entry is cached
|
||||
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
|
||||
if !ok {
|
||||
t.Error("Expected capabilities to be cached")
|
||||
}
|
||||
|
||||
// Second call - should use cache
|
||||
caps2 := Capabilities(tc.model)
|
||||
slices.Sort(caps2)
|
||||
if !slices.Equal(caps, caps2) {
|
||||
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test cache invalidation on file modification
|
||||
t.Run("cache invalidation", func(t *testing.T) {
|
||||
// Use completion model for this test
|
||||
info := ModelInfo{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
}
|
||||
|
||||
// Get initial cached entry
|
||||
cached, ok := ggufCapabilities.Load(completionModelPath)
|
||||
if !ok {
|
||||
t.Fatal("Expected model to be cached from previous tests")
|
||||
}
|
||||
entry := cached.(cacheEntry)
|
||||
|
||||
// Modify the file's timestamp to the future
|
||||
future := time.Now().Add(time.Hour)
|
||||
err := os.Chtimes(completionModelPath, future, future)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update file timestamp: %v", err)
|
||||
}
|
||||
|
||||
// Call should re-read from file due to changed modtime
|
||||
caps := Capabilities(info)
|
||||
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
|
||||
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
|
||||
}
|
||||
|
||||
// Check that cache was updated with new modtime
|
||||
cached2, ok := ggufCapabilities.Load(completionModelPath)
|
||||
if !ok {
|
||||
t.Error("Expected capabilities to be cached after re-read")
|
||||
}
|
||||
entry2 := cached2.(cacheEntry)
|
||||
if entry2.modTime.Equal(entry.modTime) {
|
||||
t.Error("Expected cache entry to have updated modTime")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -23,10 +23,9 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/server/cache"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -68,60 +67,14 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
|
||||
if m.Template == nil {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for tools capability
|
||||
if slices.Contains(m.Template.Vars(), "tools") {
|
||||
capabilities = append(capabilities, model.CapabilityTools)
|
||||
}
|
||||
|
||||
// Check for insert capability
|
||||
if slices.Contains(m.Template.Vars(), "suffix") {
|
||||
capabilities = append(capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
// Check for vision capability in projector-based models
|
||||
if len(m.ProjectorPaths) > 0 {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if openingTag != "" && closingTag != "" {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||
// any missing or unknown capabilities
|
||||
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
available := m.Capabilities()
|
||||
available := cache.Capabilities(cache.ModelInfo{
|
||||
ModelPath: m.ModelPath,
|
||||
ProjectorPaths: m.ProjectorPaths,
|
||||
Template: m.Template,
|
||||
})
|
||||
var errs []error
|
||||
|
||||
// Map capabilities to their corresponding error
|
||||
|
||||
@@ -9,131 +9,6 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
testModels := []struct {
|
||||
name string
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||
},
|
||||
|
||||
{
|
||||
name: "model with completion, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
}
|
||||
|
||||
// compare two slices of model.Capability regardless of order
|
||||
compareCapabilities := func(a, b []model.Capability) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aCount := make(map[model.Capability]int)
|
||||
for _, cap := range a {
|
||||
aCount[cap]++
|
||||
}
|
||||
|
||||
bCount := make(map[model.Capability]int)
|
||||
for _, cap := range b {
|
||||
bCount[cap]++
|
||||
}
|
||||
|
||||
for cap, count := range aCount {
|
||||
if bCount[cap] != count {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for _, tt := range testModels {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test Capabilities method
|
||||
caps := tt.model.Capabilities()
|
||||
if !compareCapabilities(caps, tt.expectedCaps) {
|
||||
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/cache"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -819,13 +820,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
|
||||
resp := &api.ShowResponse{
|
||||
License: strings.Join(m.License, "\n"),
|
||||
System: m.System,
|
||||
Template: m.Template.String(),
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
License: strings.Join(m.License, "\n"),
|
||||
System: m.System,
|
||||
Template: m.Template.String(),
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: cache.Capabilities(cache.ModelInfo{
|
||||
ModelPath: m.ModelPath,
|
||||
Template: m.Template,
|
||||
ProjectorPaths: m.ProjectorPaths,
|
||||
}),
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
}
|
||||
|
||||
var params []string
|
||||
|
||||
Reference in New Issue
Block a user