Mirror of https://github.com/ollama/ollama.git (synced 2026-01-05 14:11:04 -05:00)

Compare commits: mxyng/quan...pdevine/au (14 commits)
| SHA1 |
|---|
| 53a53702e0 |
| be04fcde16 |
| 5968989a7f |
| 8afa6e83f2 |
| ea85e27bbd |
| c116a7523d |
| 3515cc377c |
| bbf66c0b96 |
| 764be7480f |
| b72e5adb14 |
| 80b538e312 |
| 4f8a0166cc |
| 1e6eab5c33 |
| 6c733bf0a6 |
.github/workflows/release.yaml (vendored, 2 lines changed)

@@ -23,7 +23,7 @@ jobs:
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT

darwin-build:
runs-on: macos-13-xlarge
runs-on: macos-13
environment: release
needs: setup-environment
strategy:

@@ -65,7 +65,7 @@ continuation of the sentence:
Examples:

llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clairity on good commit messages, and bad
CONTRIBUTING: provide clarity on good commit messages, and bad

Bad Examples:

@@ -410,6 +410,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)

### Cloud
@@ -42,6 +42,23 @@ type Client struct {

func checkError(resp *http.Response, body []byte) error {
if resp.StatusCode < http.StatusBadRequest {
if len(body) == 0 {
return nil
}

// streams can contain error message even with StatusOK
var errorResponse struct {
Error string `json:"error,omitempty"`
}

if err := json.Unmarshal(body, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}

if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}

return nil
}

@@ -213,25 +230,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}

bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}

if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
}
}

if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
if err := checkError(response, bts); err != nil {
return err
}

if err := fn(bts); err != nil {

@@ -89,16 +89,6 @@ func TestClientStream(t *testing.T) {
},
wantErr: "mid-stream error",
},
{
name: "http status error takes precedence over general error",
responses: []any{
testError{
message: "custom error message",
statusCode: http.StatusInternalServerError,
},
},
wantErr: "500",
},
{
name: "successful stream completion",
responses: []any{
auth/auth.go (35 lines changed)

@@ -18,6 +18,8 @@ import (
const defaultPrivateKey = "id_ed25519"

var ErrInvalidToken = errors.New("invalid token")

func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {

@@ -27,6 +29,39 @@ func keyPath() (string, error) {
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}

func parseToken(token string) (key, sig []byte, _ error) {
keyData, sigData, ok := strings.Cut(token, ":")
if !ok {
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
}
sig, err := base64.StdEncoding.DecodeString(sigData)
if err != nil {
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
}
return []byte(keyData), sig, nil
}

func Authenticate(token, checkData string) (ssh.PublicKey, error) {
keyShort, sigBytes, err := parseToken(token)
if err != nil {
return nil, err
}
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
if err != nil {
return nil, err
}

if err := pub.Verify([]byte(checkData), &ssh.Signature{
Format: pub.Type(),
Blob: sigBytes,
}); err != nil {
return nil, err
}

return pub, nil
}

func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
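As an aside (not part of the diff), here is a minimal sketch of how a client could assemble a token that `Authenticate` above would accept. It assumes the check data is the `"<METHOD>,<RequestURI>"` string used by the server middleware later in this changeset, and it generates a throwaway key instead of loading `~/.ollama/id_ed25519`; error handling is omitted for brevity.

```go
package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"encoding/base64"
	"fmt"

	"golang.org/x/crypto/ssh"
)

func main() {
	// Throwaway Ed25519 key; a real client would load its private key from disk.
	_, priv, _ := ed25519.GenerateKey(rand.Reader)
	signer, _ := ssh.NewSignerFromKey(priv)

	// checkData mirrors what the server verifies: "<METHOD>,<RequestURI>".
	checkData := "GET,/api/tags"
	sig, _ := signer.Sign(rand.Reader, []byte(checkData))

	// Token layout expected by parseToken: "<base64 key body>:<base64 signature blob>".
	// The key body is the base64 portion of an "ssh-ed25519 ..." authorized-key line.
	keyPart := base64.StdEncoding.EncodeToString(signer.PublicKey().Marshal())
	sigPart := base64.StdEncoding.EncodeToString(sig.Blob)
	fmt.Println(keyPart + ":" + sigPart)
}
```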
auth/authorized_keys.go (new file, 254 lines)

@@ -0,0 +1,254 @@
package auth

import (
"bufio"
"encoding/base64"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"
)

type KeyEntry struct {
Name string
PublicKey string
Endpoints []string
}

type KeyPermission struct {
Name string
Endpoints []string
}

type APIPermissions struct {
permissions map[string]*KeyPermission
lastModified time.Time
mutex sync.RWMutex
}

var ws = regexp.MustCompile(`\s+`)

func authkeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}

return filepath.Join(home, ".ollama", "authorized_keys"), nil
}

func NewAPIPermissions() *APIPermissions {
return &APIPermissions{
permissions: make(map[string]*KeyPermission),
mutex: sync.RWMutex{},
}
}

func (ap *APIPermissions) ReloadIfNeeded() error {
ap.mutex.Lock()
defer ap.mutex.Unlock()

filename, err := authkeyPath()
if err != nil {
return err
}

fileInfo, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
}

if !fileInfo.ModTime().After(ap.lastModified) {
return nil
}

file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()

ap.lastModified = fileInfo.ModTime()
return ap.parse(file)
}

func (ap *APIPermissions) parse(r io.Reader) error {
ap.permissions = make(map[string]*KeyPermission)

scanner := bufio.NewScanner(r)
var cnt int
for scanner.Scan() {
cnt += 1
line := strings.TrimSpace(scanner.Text())

if line == "" || strings.HasPrefix(line, "#") {
continue
}

line = ws.ReplaceAllString(line, " ")

entry, err := ap.parseLine(line)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
continue
}

var pubKeyStr string

if entry.PublicKey == "*" {
pubKeyStr = "*"
} else {
pubKey, err := ap.validateAndDecodeKey(entry)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
continue
}
pubKeyStr = pubKey
}

if perm, exists := ap.permissions[pubKeyStr]; exists {
if perm.Name == "default" {
perm.Name = entry.Name
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
// skip redundant entries
continue
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
// overwrite redundant entries
perm.Endpoints = entry.Endpoints
} else {
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
}
} else {
ap.permissions[pubKeyStr] = &KeyPermission{
Name: entry.Name,
Endpoints: entry.Endpoints,
}
}
}

return scanner.Err()
}

func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
parts := strings.SplitN(line, " ", 4)
if len(parts) < 2 {
return nil, fmt.Errorf("key type and public key not found")
}

kind, b64Key := parts[0], parts[1]
name := "default"
eps := "*"

if len(parts) >= 3 && parts[2] != "" {
if parts[2] != "*" {
name = parts[2]
}
}

if len(parts) == 4 && parts[3] != "" {
eps = parts[3]
}

if kind != "ssh-ed25519" && kind != "*" {
return nil, fmt.Errorf("unsupported key type %s", kind)
}

if kind == "*" && b64Key != "*" {
return nil, fmt.Errorf("unsupported key type")
}

var endpoints []string
if eps == "*" {
endpoints = []string{"*"}
} else {
for _, e := range strings.Split(eps, ",") {
e = strings.TrimSpace(e)
if e == "" {
return nil, fmt.Errorf("empty endpoint in list")
} else if e == "*" {
endpoints = []string{"*"}
break
}
endpoints = append(endpoints, e)
}
}

return &KeyEntry{
PublicKey: b64Key,
Name: name,
Endpoints: endpoints,
}, nil
}

func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
pub, err := ssh.ParsePublicKey(keyBlob)
if err != nil {
return "", fmt.Errorf("parse key: %w", err)
}
if pub.Type() != ssh.KeyAlgoED25519 {
return "", fmt.Errorf("key is not Ed25519")
}

return entry.PublicKey, nil
}

func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
if err := ap.ReloadIfNeeded(); err != nil {
return false, "unknown", err
}

ap.mutex.RLock()
defer ap.mutex.RUnlock()

if wildcardPerm, exists := ap.permissions["*"]; exists {
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
return true, wildcardPerm.Name, nil
}

for _, allowedEndpoint := range wildcardPerm.Endpoints {
if allowedEndpoint == endpoint {
return true, wildcardPerm.Name, nil
}
}
}

keyString := string(ssh.MarshalAuthorizedKey(pubKey))
parts := strings.SplitN(keyString, " ", 2)
var base64Key string
if len(parts) > 1 {
base64Key = parts[1]
} else {
base64Key = parts[0]
}

base64Key = strings.TrimSpace(base64Key)

perm, exists := ap.permissions[base64Key]
if !exists {
return false, "unknown", nil
}

if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
return true, perm.Name, nil
}

for _, allowedEndpoint := range perm.Endpoints {
if allowedEndpoint == endpoint {
return true, perm.Name, nil
}
}

return false, "unknown", nil
}
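For reference, an illustrative `~/.ollama/authorized_keys` file (hypothetical key, names, and endpoints) in the format `parseLine` above accepts: each line is `<key-type> <base64 key> [name] [comma-separated endpoints]`, `#` starts a comment, and `*` acts as a wildcard for the key type, key, name, or endpoint list.

```
# allow everyone to list models
* * * /api/tags

# alice may call chat and generate; name and endpoints would default to "default" and "*" if omitted
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod alice /api/chat,/api/generate
```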
auth/authorized_keys_test.go (new file, 133 lines)

@@ -0,0 +1,133 @@
package auth

import (
"bytes"
"reflect"
"testing"
)

const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"

func TestParse(t *testing.T) {
tests := []struct {
name string
file string
want map[string]*KeyPermission
}{
{
name: "two fields only defaults",
file: "ssh-ed25519 " + validB64 + "\n",
want: map[string]*KeyPermission{
validB64: {
Name: "default",
Endpoints: []string{"*"},
},
},
},
{
name: "extra whitespace collapsed and default endpoints",
file: "ssh-ed25519 " + validB64 + " alice\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "four fields full",
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
want: map[string]*KeyPermission{
validB64: {
Name: "bob",
Endpoints: []string{"/api/foo", "/api/bar"},
},
},
},
{
name: "comment lines ignored and multiple entries",
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
want: map[string]*KeyPermission{
validB64: {
Name: "user1",
Endpoints: []string{"*"},
},
},
},
{
name: "three entries variety",
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "two entries w/ wildcard",
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
want: map[string]*KeyPermission{
validB64: {
Name: "alice",
Endpoints: []string{"/api/a"},
},
"*": {
Name: "default",
Endpoints: []string{"/api/b"},
},
},
},
{
name: "tags for everyone",
file: "* * * /api/tags",
want: map[string]*KeyPermission{
"*": {
Name: "default",
Endpoints: []string{"/api/tags"},
},
},
},
{
name: "default name",
file: "* * somename",
want: map[string]*KeyPermission{
"*": {
Name: "somename",
Endpoints: []string{"*"},
},
},
},
{
name: "unsupported key type",
file: "ssh-rsa AAAAB3Nza...\n",
want: map[string]*KeyPermission{},
},
{
name: "bad base64",
file: "ssh-ed25519 invalid@@@\n",
want: map[string]*KeyPermission{},
},
{
name: "just an asterix",
file: "*\n",
want: map[string]*KeyPermission{},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
perms := NewAPIPermissions()
err := perms.parse(bytes.NewBufferString(tc.file))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(perms.permissions) != len(tc.want) {
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
}
if !reflect.DeepEqual(perms.permissions, tc.want) {
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
}
})
}
}
@@ -1137,6 +1137,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if errors.Is(err, context.Canceled) {
return nil, nil
}

// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
@@ -11,14 +11,13 @@ import (
"io"
"io/fs"
"log/slog"
"maps"
"os"
"path/filepath"
"slices"
"strings"
"testing"

"golang.org/x/exp/maps"

"github.com/ollama/ollama/fs/ggml"
)

@@ -137,9 +136,7 @@ func TestConvertModel(t *testing.T) {
t.Fatal(err)
}

keys := maps.Keys(expect)
slices.Sort(keys)
for _, k := range keys {
for _, k := range slices.Sorted(maps.Keys(expect)) {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != expect[k] {

@@ -343,9 +340,7 @@ func TestConvertAdapter(t *testing.T) {
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())

keys := maps.Keys(c.Expected)
slices.Sort(keys)
for _, k := range keys {
for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != c.Expected[k] {

@@ -8,12 +8,12 @@ import (
"fmt"
"io"
"io/fs"
"maps"
"slices"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/x448/float16"
"golang.org/x/exp/maps"
)

type safetensorMetadata struct {

@@ -46,8 +46,7 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
return nil, err
}

keys := maps.Keys(headers)
slices.Sort(keys)
keys := slices.Sorted(maps.Keys(headers))

names := make(map[string]struct{}, len(keys))

@@ -8,11 +8,10 @@ import (
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
"slices"
"strings"

"golang.org/x/exp/maps"
)

const (

@@ -260,11 +259,8 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
tokens[token.ID] = token
}

keys := maps.Keys(tokens)
slices.Sort(keys)

v := Vocabulary{Model: "gpt2"}
for _, k := range keys {
for _, k := range slices.Sorted(maps.Keys(tokens)) {
token := tokens[k]
v.Tokens = append(v.Tokens, token.Content)
v.Scores = append(v.Scores, float32(token.ID))
@@ -500,11 +500,11 @@ The `message` object has the following fields:
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result

Advanced parameters (optional):

- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./...
```

> NOTE: In rare cirumstances, you may need to change a package using the new
> NOTE: In rare circumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or

@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
# Define the schema for the response
class FriendInfo(BaseModel):
name: str
age: int
age: int
is_available: bool

class FriendList(BaseModel):

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:

```shell
journalctl -u ollama --no-pager --follow --pager-end
journalctl -u ollama --no-pager --follow --pager-end
```

When you run Ollama in a **container**, the logs go to stdout/stderr in the container:

@@ -23,7 +23,7 @@ docker logs <container-name>
If manually running `ollama serve` in a terminal, the logs will be on that terminal.

When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored

@@ -38,7 +38,7 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.

## LLM libraries

Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.

In the server log, you will see a message that looks something like this (varies from release to release):

@@ -97,7 +97,7 @@ If none of those resolve the problem, gather additional information and file an
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.

When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`

If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
go.mod (2 lines changed)

@@ -71,7 +71,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
@@ -25,6 +25,9 @@ type Causal struct {
opts CausalOptions

// maxBatch is the largest batch that we might receive
maxBatch int

// config controls mostly backend-specific optimizations
config *ml.CacheConfig

@@ -147,6 +150,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.DType = dtype
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.maxBatch = maxBatch
}

func (c *Causal) SetConfig(config ml.CacheConfig) {

@@ -639,48 +643,64 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
return ErrNotSupported
}

ctx := c.backend.NewContext()
defer ctx.Close()

seqRange := c.cellRanges[seq]
size := seqRange.max - seqRange.min + 1

offsets := make([]int32, size)
for i := range offsets {
cell := c.cells[seqRange.min+i]
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
size := min(seqRange.max-start+1, c.maxBatch)
offsets := make([]int32, size)

if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
var batchFirst, batchLast int

batchFirst = -1
for i := range offsets {
cell := c.cells[start+i]

if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
if batchFirst < 0 {
batchFirst = i
}
batchLast = i
}
}
}

kShift := ctx.Input().FromIntSlice(offsets, len(offsets))

for i, key := range c.keys {
if key == nil {
if batchFirst < 0 {
continue
}

kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
offsets = offsets[batchFirst : batchLast+1]

key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)
ctx := c.backend.NewContext()
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))

roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
return err
for i, key := range c.keys {
if key == nil {
continue
}

kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)

key = key.View(ctx, rowSize*(start+batchFirst),
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
len(offsets),
)

roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
ctx.Close()
return err
}

ctx.Forward(roped.Copy(ctx, key))
}

ctx.Forward(roped.Copy(ctx, key))
ctx.Compute()
ctx.Close()
}

ctx.Compute()

return nil
}
@@ -16,7 +16,7 @@ ggml-ci
2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index ee4f2dcb..f20f5615 100644
index a9eeebc6..110c9ece 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {

@@ -52,7 +52,7 @@ index 64fb4ff4..5b9a0fe3 100644
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 4c829153..9e64e5ae 100644
index d6960174..2b9fabf4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -35,6 +35,7 @@

llama/patches/0021-Enable-CUDA-Graphs-for-gemma3n.patch (new file, 50 lines)

@@ -0,0 +1,50 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Oliver Simons <osimons@nvidia.com>
Date: Tue, 22 Jul 2025 11:02:28 +0200
Subject: [PATCH] Enable CUDA Graphs for gemma3n.

Similar to
https://github.com/ggml-org/llama.cpp/pull/14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
---
ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2b9fabf4..28ccf4be 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

+ const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+ const std::string gemma3n_node_name = "node_";
+
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+ // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+ && node->ne[2] == 1
+ && node->ne[3] == 1
+ && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+ && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+ GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu (vendored, 16 lines changed)

@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
const std::string gemma3n_node_name = "node_";

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
// workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
// number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
&& node->ne[2] == 1
&& node->ne[3] == 1
&& node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
&& node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
@@ -5,7 +5,6 @@ package ggml
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/ggml-cpu
// #cgo windows CFLAGS: -Wno-dll-attribute-on-redeclaration
// #cgo windows LDFLAGS: -lmsvcrt -static -static-libgcc -static-libstdc++
// #cgo windows linux CPPFLAGS: -DGGML_FP16_TO_FP32=ggml_compute_fp16_to_fp32
// #include <stdlib.h>
// #include "ggml-backend.h"
// extern void sink(int level, char *text, void *user_data);

@@ -10,8 +10,6 @@ package ggml
import "C"

import (
"iter"
"slices"
"unsafe"

fsggml "github.com/ollama/ollama/fs/ggml"

@@ -52,30 +50,34 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
return f32s
}

func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
C.ggml_quantize_init(uint32(newType))
defer C.ggml_quantize_free()

dims := slices.Repeat([]C.int64_t{1}, 4)
for i, s := range shape {
dims[i] = C.int64_t(s)
}

bts := make([]byte, C.ggml_row_size(uint32(newType), dims[0])*C.size_t(dims[1]))
for chunk := range dims[2] {
offset := chunk * dims[0] * dims[1]

n := C.ggml_quantize_chunk(
uint32(newType),
(*C.float)(&f32s[0]),
unsafe.Pointer(&bts[0]),
offset, dims[1], dims[0], nil,
)

if !yield(bts[:n]) {
return
}
}
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) []byte {
buf := make([]byte, len(f32s)*4) // upper bound on size
nPerRow := C.int64_t(shape[0])
nrows := C.int64_t(1)
if len(shape) > 1 {
nrows = C.int64_t(shape[1])
}
shape2 := C.int64_t(1)
if len(shape) > 2 {
shape2 = C.int64_t(shape[2])
}
nelements_matrix := nPerRow * nrows
newSize := C.size_t(0)
for i03 := C.int64_t(0); i03 < shape2; i03++ {
f32s_03 := i03 * nelements_matrix
buf_03 := C.int64_t(C.ggml_row_size(uint32(newType), nPerRow)) * i03 * nrows
newSize += C.ggml_quantize_chunk(
uint32(newType),
(*C.float)(&f32s[f32s_03]),
unsafe.Pointer((uintptr)(unsafe.Pointer(&buf[0]))+uintptr(buf_03)),
0,
nrows,
nPerRow,
nil)
}
return buf[:newSize]
}

func QuantizationVersion() uint32 {
return uint32(C.GGML_QNT_VERSION)
}

@@ -203,10 +203,9 @@ func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))

hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
predictions := coefficients.Mulmat(ctx, hiddenStates)
predictions = predictions.Add(ctx, hiddenStates)
return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
return predictions.Add(ctx, hiddenStates)
}

func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {

@@ -40,19 +40,10 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
}

var n int64
for bts := range ggml.Quantize(newType, f32s, q.from.Shape) {
nn, err := w.Write(bts)
if err != nil {
return 0, err
}

q.progressFn(uint64(nn))
n += int64(nn)
}

return n, err
data = ggml.Quantize(newType, f32s, q.from.Shape)
n, err := w.Write(data)
q.progressFn(q.from.Size())
return int64(n), err
}

type quantizeState struct {
@@ -2,14 +2,12 @@ package server

import (
"bytes"
"encoding/binary"
"fmt"
"math"
"os"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)

@@ -651,55 +649,3 @@ var (
},
}
)

func TestQuantizer(t *testing.T) {
from := fsggml.Tensor{
Name: "fp32",
Shape: []uint64{256},
Kind: uint32(fsggml.TensorTypeF32),
}

temp, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}

f32s := make([]float32, 256)
for i := range f32s {
f32s[i] = float32(i)
}

if err := binary.Write(temp, binary.LittleEndian, f32s); err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}

for type_, want := range quantBytes {
t.Run(type_.String(), func(t *testing.T) {
f, err := os.Open(temp.Name())
if err != nil {
t.Fatalf("failed to open temp file: %v", err)
}
defer f.Close()

q := quantizer{
File: f,
from: &from,
to: &fsggml.Tensor{
Name: type_.String(),
Shape: from.Shape,
Kind: uint32(type_),
},
progressFn: func(uint64) {},
}

var b bytes.Buffer
if _, err := q.WriteTo(&b); err != nil {
t.Fatalf("WriteTo failed: %v", err)
}

if diff := cmp.Diff(b.Bytes(), want); diff != "" {
t.Errorf("quantized data mismatch for %s (-got +want):\n%s", type_, diff)
}
})
}
}
@@ -28,6 +28,7 @@ import (
"golang.org/x/sync/errgroup"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"

@@ -55,6 +56,8 @@ var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler

perms *auth.APIPermissions
}

func init() {

@@ -69,6 +72,38 @@ func init() {
gin.SetMode(mode)
}

func loggedFormatter(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}

if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}

username := "default"
if userVal, exists := param.Keys["username"]; exists {
if name, ok := userVal.(string); ok {
username = name
}
}

return fmt.Sprintf(
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
username,
methodColor, param.Method, resetColor,
param.Path,
param.ErrorMessage,
)
}

var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")

@@ -1111,6 +1146,43 @@ func allowedHost(host string) bool {
return false
}

func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
return func(c *gin.Context) {
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
c.Next()
return
}

token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
if err != nil {
slog.Error("authentication error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
c.Set("username", name)
if err != nil {
slog.Error("authorization error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

if !authorized {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

c.Next()
}
}

func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {

@@ -1177,10 +1249,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()

r := gin.Default()
r := gin.New()
r.HandleMethodNotAllowed = true
r.Use(
gin.LoggerWithFormatter(loggedFormatter),
gin.Recovery(),
cors.New(corsConfig),
allowedEndpointsMiddleware(s.perms),
allowedHostsMiddleware(s.addr),
)

@@ -1190,7 +1265,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })

// Local model cache management (new implementation is at end of function)
// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.HEAD("/api/tags", s.ListHandler)

@@ -1222,7 +1297,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// wrap old with new
rs := &registry.Local{
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Logger: slog.Default(),
Fallback: r,

Prune: PruneLayers,

@@ -1267,6 +1342,12 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}

if envconfig.UseAuth() {
perms := auth.NewAPIPermissions()
perms.ReloadIfNeeded()
s.perms = perms
}

var rc *ollama.Registry
if useClient2 {
var err error
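An illustrative request against the new middleware (assuming the server listens on the usual 127.0.0.1:11434): the signature inside the bearer token must cover the "<METHOD>,<RequestURI>" string that `allowedEndpointsMiddleware` passes to `auth.Authenticate`, for example "GET,/api/tags" here.

```shell
# <token> is "<base64 key>:<base64 signature>", e.g. as produced by the Go sketch shown earlier
curl http://localhost:11434/api/tags -H "Authorization: Bearer <token>"
```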
@@ -16,6 +16,7 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"sort"
"strings"
"testing"

@@ -82,19 +83,6 @@ func createTestFile(t *testing.T, name string) (string, string) {
return f.Name(), digest
}

// equalStringSlices checks if two slices of strings are equal.
func equalStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

type panicTransport struct{}

func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {

@@ -447,7 +435,7 @@ func TestRoutes(t *testing.T) {
"stop \"foo\"",
"top_p 0.9",
}
if !equalStringSlices(params, expectedParams) {
if !slices.Equal(params, expectedParams) {
t.Errorf("expected parameters %v, got %v", expectedParams, params)
}
paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"io"
"maps"
"math"
"slices"
"strings"

@@ -14,7 +15,6 @@ import (
"text/template/parse"

"github.com/agnivade/levenshtein"
"golang.org/x/exp/maps"

"github.com/ollama/ollama/api"
)

@@ -157,9 +157,7 @@ func (t *Template) Vars() []string {
set[strings.ToLower(n)] = struct{}{}
}

vars = maps.Keys(set)
slices.Sort(vars)
return vars
return slices.Sorted(maps.Keys(set))
}

type Values struct {
tools/tools.go (125 lines changed)

@@ -120,16 +120,14 @@ func (p *Parser) parseToolCall() *api.ToolCall {
return nil
}

// only look for arguments after the tool name if the tool has parameters
// TODO (jmorganca): while probably uncommon, this doesn't support
// parsing arguments before the tool name, which may be needed in the future
args := map[string]any{}
if len(tool.Function.Parameters.Properties) > 0 {
var i int
if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
return nil
var args map[string]any
if found, i := findArguments(p.buffer); found == nil {
return nil
} else {
args = found
if i > end {
end = i
}
end += i
}

tc := &api.ToolCall{

@@ -217,93 +215,70 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
func findArguments(buffer []byte) (map[string]any, int) {
if len(buffer) == 0 {
return nil, 0
}

var braces int
var start int = -1
var end int
var object []byte

// find any outer json object
for i, c := range buffer {
if c == '{' {
braces++
if start == -1 {
if braces == 0 {
start = i
}
}
braces++
} else if c == '}' && braces > 0 {
braces--
if braces == 0 && start != -1 {
object := buffer[start : i+1]

if c == '}' {
if start != -1 {
braces--
if braces == 0 {
end = i + 1
object = buffer[start:end]
break
var data map[string]any
if err := json.Unmarshal(object, &data); err != nil {
start = -1
continue
}
}
}
}

if braces > 0 {
return nil, 0
}

var data map[string]any
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}

var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch obj := obj.(type) {
case map[string]any:
valid := true
// check if all keys in the object exist in the tool's parameters
for key := range obj {
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
valid = false
break
}
}

// check for required parameters
// TODO (jmorganca): this should error instead of silently failing
if valid {
for _, required := range tool.Function.Parameters.Required {
if _, exists := obj[required]; !exists {
valid = false
break
var findObject func(obj map[string]any) (map[string]any, bool)
findObject = func(obj map[string]any) (map[string]any, bool) {
if _, hasName := obj["name"]; hasName {
if args, ok := obj["arguments"].(map[string]any); ok {
return args, true
}
if args, ok := obj["parameters"].(map[string]any); ok {
return args, true
}
return nil, true
}
}
}

if valid {
return obj
}
for _, v := range obj {
switch child := v.(type) {
case map[string]any:
if result, found := findObject(child); found {
return result, true
}
case []any:
for _, item := range child {
if childObj, ok := item.(map[string]any); ok {
if result, found := findObject(childObj); found {
return result, true
}
}
}
}
}

for _, value := range obj {
if result := find(value); result != nil {
return result
return nil, false
}
}
case []any:
for _, item := range obj {
if result := find(item); result != nil {
return result

if args, found := findObject(data); found {
return args, i
}

return data, i
}
}

return nil
}

result := find(data)
if result != nil {
return result, end
}

return nil, 0
@@ -227,13 +227,6 @@ func TestParser(t *testing.T) {
},
},
},
{
name: "invalid arguments",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "empty args",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {}}</tool_call>`},

@@ -249,13 +242,6 @@
},
},
},
{
name: "missing required args",
inputs: []string{`<tool_call>{"name": "get_temperature", "arguments": {}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},

@@ -273,21 +259,6 @@
},
},
},
{
name: "qwen no args tool call",
inputs: []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
content: "Let me say hello to the user. I'll use the say_hello tool ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "qwen no args with text",
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},

@@ -521,52 +492,6 @@
content: "for { fmt.Println(\"hello\") }",
tmpl: json,
},
{
name: "json no args tool call",
inputs: []string{
"{\"name\": \"say_hello\"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "json no args no tool call",
inputs: []string{
"I'll use the say_hello tool to say hello to the user.",
},
content: "I'll use the say_hello tool to say hello to the user.",
tmpl: json,
calls: nil,
},

// TODO (jmorganca): this is a false positive, we should
// not be parsing this as a tool call
{
name: "json no args false positive",
inputs: []string{
`{say_hello!!!}`,
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "list multiple",
inputs: []string{

@@ -684,26 +609,6 @@
tmpl: list,
calls: nil,
},
{
name: "list with no arguments",
inputs: []string{
"[",
"{",
"\"name\": \"say_hello\"",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
},
{
name: "tool name with collision",
inputs: []string{

@@ -711,7 +616,7 @@
"{",
"\"name\": \"say_hello",
"_world\",",
"}",
"\"arguments\": {}}",
"}",
},
content: "",

@@ -733,13 +638,13 @@
"{",
"\"name\": \"say_hello",
"_world\",",
"}",
"\"arguments\": {}}",
"</tool_call>",
"<tool_call>",
"{",
"\"name\": \"say_hello",
"\",",
"}",
"\"arguments\": {}}",
"</tool_call>",
},
content: "",

@@ -773,7 +678,7 @@
{
name: "tool name with collision non streaming multiple",
inputs: []string{
`<tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "say_hello_world"}`,
`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "say_hello_world", "arguments": {}}`,
},
content: "",
tmpl: qwen,

@@ -797,7 +702,7 @@
{
name: "tool name with collision non streaming shorter",
inputs: []string{
`<tool_call>{"name": "say_hello"}</tool_call>`,
`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call>`,
},
content: "",
tmpl: qwen,

@@ -814,7 +719,7 @@
{
name: "tool name with collision non streaming longer",
inputs: []string{
`<tool_call>{"name": "say_hello_world"}</tool_call>`,
`<tool_call>{"name": "say_hello_world", "arguments": {}}</tool_call>`,
},
content: "",
tmpl: qwen,

@@ -871,6 +776,26 @@
},
},
},
{
name: "args before name",
inputs: []string{
`<tool_call>{"arguments": {"a": "5", "b": "10"}, "name": "add"}</tool_call>`,
},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "add",
Arguments: api.ToolCallFunctionArguments{
"a": "5",
"b": "10",
},
},
},
},
},
}

for _, tt := range tests {

@@ -1167,75 +1092,25 @@ func TestFindTag(t *testing.T) {
}

func TestFindArguments(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the temperature for",
},
},
},
},
}

tool2 := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello to the user",
},
}

tests := []struct {
name string
buffer []byte
want map[string]any
tool api.Tool
}{
{
name: "empty string",
buffer: []byte{},
want: nil,
tool: tool,
},
{
name: "whitespace only",
buffer: []byte(" \n\t "),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - missing closing",
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
want: nil,
tool: tool,
},
{
name: "unbalanced braces - extra closing",

@@ -1243,13 +1118,11 @@
want: map[string]any{
"format": "fahrenheit",
},
tool: tool,
},
{
name: "invalid JSON",
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
want: nil,
tool: tool,
},
{
name: "valid json",

@@ -1258,7 +1131,6 @@
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments with special tokens",

@@ -1267,16 +1139,14 @@
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "valid arguments in array",
buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
buffer: []byte(`[{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "nested deep",

@@ -1285,7 +1155,6 @@
"format": "fahrenheit",
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "one arg",

@@ -1293,7 +1162,6 @@
want: map[string]any{
"location": "San Francisco, CA",
},
tool: tool,
},
{
name: "two args",

@@ -1302,13 +1170,6 @@
"location": "San Francisco, CA",
"format": "fahrenheit",
},
tool: tool,
},
{
name: "no args",
buffer: []byte(`{"name": "say_hello"}`),
want: nil,
tool: tool2,
},
{
name: "deepseek",

@@ -1316,7 +1177,6 @@
want: map[string]any{
"location": "Tokyo",
},
tool: tool,
},
{
name: "deepseek",

@@ -1324,13 +1184,12 @@
want: map[string]any{
"location": "Tokyo",
},
tool: tool,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := findArguments(tt.tool, tt.buffer)
got, _ := findArguments(tt.buffer)

if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)