mirror of
https://github.com/ollama/ollama.git
synced 2026-01-13 09:59:08 -05:00
Compare commits
2 Commits
drifkin/ch
...
parth/pyth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23e8ac9428 | ||
|
|
611d3a17ed |
@@ -34,15 +34,13 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
|
||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
res, err := embeddingTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -64,15 +62,13 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -102,15 +98,13 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@@ -150,8 +144,6 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
@@ -190,7 +182,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
response, err := embedTestHelper(ctx, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@@ -198,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
|
||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request to truncate correctly. Wanted: ", res["Target Truncation"].Embeddings[0][0], "Got: ", res["Default Truncate"].Embeddings[0][0])
|
||||
t.Fatal("expected default request to truncate correctly")
|
||||
}
|
||||
|
||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||
@@ -206,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
@@ -218,7 +210,9 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
@@ -232,7 +226,9 @@ func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T,
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
226
server/python_tools.go
Normal file
226
server/python_tools.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var (
|
||||
pythonFuncRegex = regexp.MustCompile(`(\w+)\((.*?)\)`)
|
||||
braces = map[rune]rune{
|
||||
'[': ']',
|
||||
'{': '}',
|
||||
'(': ')',
|
||||
'"': '"',
|
||||
'\'': '\'',
|
||||
}
|
||||
)
|
||||
|
||||
// parsePythonValue converts a Python value string to its appropriate Go type
|
||||
func parsePythonValue(value string) (any, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// string
|
||||
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
|
||||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
|
||||
// Remove quotes
|
||||
result := value[1 : len(value)-1]
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// bool
|
||||
switch strings.ToLower(value) {
|
||||
case "true":
|
||||
return true, nil
|
||||
case "false":
|
||||
return false, nil
|
||||
case "none":
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// int
|
||||
if i, err := strconv.Atoi(value); err == nil {
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// float
|
||||
if f, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// list
|
||||
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
|
||||
listStr := value[1 : len(value)-1]
|
||||
var list []any
|
||||
stack := []rune{}
|
||||
start := 0
|
||||
|
||||
for i, char := range listStr {
|
||||
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
|
||||
stack = stack[:len(stack)-1]
|
||||
} else if _, ok := braces[char]; ok {
|
||||
stack = append(stack, char)
|
||||
}
|
||||
|
||||
if len(stack) == 0 && (char == ',' || i == len(listStr)-1) {
|
||||
end := i
|
||||
if i == len(listStr)-1 {
|
||||
end = i + 1
|
||||
}
|
||||
item := strings.TrimSpace(listStr[start:end])
|
||||
if val, err := parsePythonValue(item); err == nil {
|
||||
list = append(list, val)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid list item: %s", item)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// dictionary
|
||||
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") && strings.Contains(value, ":") {
|
||||
dictStr := value[1 : len(value)-1]
|
||||
dict := make(map[any]any)
|
||||
stack := []rune{}
|
||||
start := 0
|
||||
for i, char := range dictStr {
|
||||
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
|
||||
stack = stack[:len(stack)-1]
|
||||
} else if _, ok := braces[char]; ok {
|
||||
stack = append(stack, char)
|
||||
}
|
||||
if len(stack) == 0 && (char == ',' || i == len(dictStr)-1) {
|
||||
end := i
|
||||
if i == len(dictStr)-1 {
|
||||
end = i + 1
|
||||
}
|
||||
item := strings.TrimSpace(dictStr[start:end])
|
||||
kv := strings.SplitN(item, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
return nil, fmt.Errorf("invalid dictionary key-value pair: %s", item)
|
||||
}
|
||||
|
||||
key, err := parsePythonValue(strings.TrimSpace(kv[0]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid dictionary key: %s", kv[0])
|
||||
}
|
||||
|
||||
val, err := parsePythonValue(strings.TrimSpace(kv[1]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid dictionary value: %s", kv[1])
|
||||
}
|
||||
|
||||
dict[key] = val
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return dict, nil
|
||||
}
|
||||
|
||||
// sets (stored as lists)
|
||||
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
|
||||
setStr := value[1 : len(value)-1]
|
||||
var list []any
|
||||
stack := []rune{}
|
||||
start := 0
|
||||
for i, char := range setStr {
|
||||
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
|
||||
stack = stack[:len(stack)-1]
|
||||
} else if _, ok := braces[char]; ok {
|
||||
stack = append(stack, char)
|
||||
}
|
||||
if len(stack) == 0 && (char == ',' || i == len(setStr)-1) {
|
||||
end := i
|
||||
if i == len(setStr)-1 {
|
||||
end = i + 1
|
||||
}
|
||||
item := strings.TrimSpace(setStr[start:end])
|
||||
if val, err := parsePythonValue(item); err == nil {
|
||||
list = append(list, val)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid set item: %s", item)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid Python value: %s", value)
|
||||
}
|
||||
|
||||
// parsePythonToolCall parses Python function calls from a string
|
||||
// it supports keyword arguments, as well as multiple functions in a single string
|
||||
func parsePythonToolCall(s string) ([]api.ToolCall, error) {
|
||||
matches := pythonFuncRegex.FindAllStringSubmatchIndex(s, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("no Python function calls found")
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, match := range matches {
|
||||
name := s[match[2]:match[3]]
|
||||
args := s[match[4]:match[5]]
|
||||
var arguments api.ToolCallFunctionArguments
|
||||
if len(args) == 0 {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
start := 0
|
||||
stack := []rune{}
|
||||
for i, char := range args {
|
||||
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
|
||||
stack = stack[:len(stack)-1]
|
||||
} else if _, ok := braces[char]; ok {
|
||||
stack = append(stack, char)
|
||||
}
|
||||
if len(stack) == 0 && (char == ',' || i == len(args)-1) {
|
||||
end := i
|
||||
if i == len(args)-1 {
|
||||
end = i + 1
|
||||
}
|
||||
kv := strings.SplitN(args[start:end], "=", 2)
|
||||
if len(kv) == 2 {
|
||||
key := strings.TrimSpace(kv[0])
|
||||
valueStr := strings.TrimSpace(kv[1])
|
||||
|
||||
// Parse the value into appropriate type
|
||||
value, err := parsePythonValue(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for key %q: %v", key, err)
|
||||
}
|
||||
|
||||
arguments[key] = value
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid argument format: %q", args[start:end])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
if len(arguments) > 0 {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return toolCalls, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse any valid tool calls")
|
||||
}
|
||||
269
server/python_tools_test.go
Normal file
269
server/python_tools_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestParsePythonFunctionCall(t *testing.T) {
|
||||
t1 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "San Francisco, CA",
|
||||
"format": "fahrenheit",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t2 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_forecast",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"days": 5,
|
||||
"location": "Seattle",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t3 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"list": []any{1, 2, 3},
|
||||
"int": -1,
|
||||
"float": 1.23,
|
||||
"string": "hello",
|
||||
},
|
||||
},
|
||||
}
|
||||
t4 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want []api.ToolCall
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
name: "malformed function call - missing closing paren",
|
||||
input: "get_current_weather(location=\"San Francisco\"",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "empty function call",
|
||||
input: "get_current_weather()",
|
||||
want: []api.ToolCall{t4},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "single valid function call",
|
||||
input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\")",
|
||||
want: []api.ToolCall{t1},
|
||||
},
|
||||
{
|
||||
name: "multiple valid function calls",
|
||||
input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\") get_forecast(days=5, location=\"Seattle\")",
|
||||
want: []api.ToolCall{t1, t2},
|
||||
},
|
||||
{
|
||||
name: "multiple valid function calls with list",
|
||||
input: "get_current_weather(list=[1,2,3], int=-1, float=1.23, string=\"hello\")",
|
||||
want: []api.ToolCall{t3},
|
||||
},
|
||||
{
|
||||
name: "positional arguments not supported",
|
||||
input: "get_current_weather(1, 2, 3)",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "invalid argument format without equals",
|
||||
input: "get_current_weather(\"San Francisco\")",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "nested lists",
|
||||
input: "get_current_weather(data=[[1,2],[3,4]])",
|
||||
want: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"data": []any{[]any{1, 2}, []any{3, 4}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "boolean and none values",
|
||||
input: "get_current_weather(active=true, enabled=false, value=None)",
|
||||
want: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"active": true,
|
||||
"enabled": false,
|
||||
"value": nil,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "single vs double quotes",
|
||||
input: "get_current_weather(str1='single', str2=\"double\")",
|
||||
want: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"str1": "single",
|
||||
"str2": "double",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "whitespace handling",
|
||||
input: "get_current_weather( location = \"San Francisco\" , temp = 72 )",
|
||||
want: []api.ToolCall{{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "San Francisco",
|
||||
"temp": 72,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parsePythonToolCall(tt.input)
|
||||
if (err != nil) != tt.err {
|
||||
t.Fatalf("expected error: %v, got error: %v", tt.err, err)
|
||||
}
|
||||
if tt.err {
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePythonValue(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want any
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
name: "string with double quotes",
|
||||
input: "\"hello\"",
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "string with single quotes",
|
||||
input: "'world'",
|
||||
want: "world",
|
||||
},
|
||||
{
|
||||
name: "integer",
|
||||
input: "42",
|
||||
want: 42,
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
input: "3.14",
|
||||
want: 3.14,
|
||||
},
|
||||
{
|
||||
name: "boolean true",
|
||||
input: "True",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "boolean false",
|
||||
input: "False",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "none/null",
|
||||
input: "None",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "simple list",
|
||||
input: "[1, 2, 3]",
|
||||
want: []any{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "nested list",
|
||||
input: "[1, [2, 3], 4]",
|
||||
want: []any{1, []any{2, 3}, 4},
|
||||
},
|
||||
{
|
||||
name: "mixed type list",
|
||||
input: "[1, \"two\", 3.0, true]",
|
||||
want: []any{1, "two", 3.0, true},
|
||||
},
|
||||
{
|
||||
name: "invalid list",
|
||||
input: "[1, 2,",
|
||||
want: nil,
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "dictionaries",
|
||||
input: "{'a': 1, 'b': 2}",
|
||||
want: map[any]any{"a": 1, "b": 2},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "int dictionary",
|
||||
input: "{1: 2}",
|
||||
want: map[any]any{1: 2},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "mixed type dictionary",
|
||||
input: "{'a': 1, 'b': 2.0, 'c': True}",
|
||||
want: map[any]any{"a": 1, "b": 2.0, "c": true},
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
name: "invalid dictionary - missing closing brace",
|
||||
input: "{'a': 1, 'b': 2",
|
||||
want: nil,
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
name: "sets",
|
||||
input: "{1, 2, 3}",
|
||||
want: []any{1, 2, 3},
|
||||
err: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parsePythonValue(tt.input)
|
||||
if (err != nil) != tt.err {
|
||||
t.Fatalf("expected error: %v, got error: %v", tt.err, err)
|
||||
}
|
||||
if tt.err {
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -114,16 +114,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// TODO(drifkin): `GetRunner` above changes opts, but we currently pass it by
|
||||
// value. The following line is a hack to fix this for the now dynaically
|
||||
// calculated NumCtx, but we should fix this properly (which could have other
|
||||
// side effects, since perhaps we were relying on the values not being stomped
|
||||
// on, particularly when NumCtx sometimes represents a numParallel-adjusted
|
||||
// number and sometimes not)
|
||||
if opts.NumCtx == -1 {
|
||||
opts.NumCtx = runner.Options.NumCtx / runner.numParallel
|
||||
}
|
||||
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -80,12 +80,8 @@ func TestGenerateChat(t *testing.T) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
opts := api.DefaultOptions()
|
||||
opts.NumCtx = 4096
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
numParallel: 1,
|
||||
Options: &opts,
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -188,8 +184,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
|
||||
t.Run("load model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Options: map[string]any{"num_ctx": 2048},
|
||||
Model: "test",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -639,9 +634,7 @@ func TestGenerate(t *testing.T) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
Options: &api.Options{},
|
||||
numParallel: 1,
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -757,8 +750,7 @@ func TestGenerate(t *testing.T) {
|
||||
|
||||
t.Run("load model", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Options: map[string]any{"num_ctx": 2048},
|
||||
Model: "test",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
|
||||
@@ -81,9 +81,6 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx != -1 && opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
req := &LlmRequest{
|
||||
ctx: c,
|
||||
model: model,
|
||||
@@ -588,16 +585,6 @@ func (runner *runnerRef) unload() {
|
||||
runner.gpus = nil
|
||||
}
|
||||
|
||||
func runnerOptionsEqual(a, b api.Runner) bool {
|
||||
// if one of the options is -1, then it means it needs to be dynamically calculated
|
||||
if a.NumCtx == -1 {
|
||||
a.NumCtx = b.NumCtx
|
||||
} else if b.NumCtx == -1 {
|
||||
b.NumCtx = a.NumCtx
|
||||
}
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
||||
runner.refMu.Lock()
|
||||
@@ -627,7 +614,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
!runnerOptionsEqual(optsExisting, optsNew) ||
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -148,6 +148,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
b.req.opts.NumCtx = 4096
|
||||
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||
return b
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user