mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 12:57:56 -05:00
Compare commits
1 Commits
parth/decr
...
parth/deep
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
478824045d |
@@ -1523,8 +1523,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolParser *tools.Parser
|
var toolParser tools.ToolParser
|
||||||
if len(req.Tools) > 0 {
|
|
||||||
|
fmt.Println("m.Config.ModelFamily", m.Config.ModelFamily)
|
||||||
|
if m.Config.ModelFamily == "qwen" {
|
||||||
|
slog.Info("using deepseek tool parser")
|
||||||
|
fmt.Println("m.Template.Template", m.Template.Template)
|
||||||
|
toolParser, err = tools.NewDeepSeekToolParser(m.Template.Template)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to create tool parser", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if len(req.Tools) > 0 {
|
||||||
|
slog.Info("using default tool parser")
|
||||||
toolParser, err = tools.NewParser(m.Template.Template)
|
toolParser, err = tools.NewParser(m.Template.Template)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to create tool parser", "error", err)
|
slog.Error("failed to create tool parser", "error", err)
|
||||||
|
|||||||
179
tools/deepseek_tools.go
Normal file
179
tools/deepseek_tools.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
gotmpl "text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DeepSeekToolParser struct {
|
||||||
|
parser *Parser // Embed the base parser as a field
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeekToolParser) Add(s string) (tools []api.ToolCall, content string) {
|
||||||
|
fmt.Println("prefix", p.parser.prefix)
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Starting with input:", s)
|
||||||
|
p.parser.sb.WriteString(s)
|
||||||
|
s = p.parser.sb.String()
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: After StringBuilder:", s)
|
||||||
|
|
||||||
|
// Check for prefix pattern in input
|
||||||
|
s, err := p.parser.checkPrefix(s)
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: After checkPrefix:", s, "error:", err)
|
||||||
|
if err != nil {
|
||||||
|
// Need more input to complete prefix
|
||||||
|
return nil, s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||||
|
if !p.parser.prefixFound {
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Prefix not found, resetting")
|
||||||
|
p.parser.sb.Reset()
|
||||||
|
return nil, s
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls, err := parseDeepSeekToolCalls(s)
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: After parseDeepSeekToolCalls:", toolCalls, "error:", err)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errAccumulateMore) {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
p.parser.sb.Reset()
|
||||||
|
// Only do greedy JSON parsing if there is no prefix from template
|
||||||
|
if p.parser.prefix != "" {
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Disabling greedy parsing")
|
||||||
|
p.parser.greedyParseJSON = false
|
||||||
|
}
|
||||||
|
if p.parser.index != 0 && p.parser.prefix == "" {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
if p.parser.prefixFound {
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Prefix found but invalid tool call")
|
||||||
|
// Drop tokens since prefix was found
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
return nil, s
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Processing tool calls")
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
tc.Function.Index = p.parser.index
|
||||||
|
p.parser.index++
|
||||||
|
}
|
||||||
|
|
||||||
|
p.parser.sb.Reset()
|
||||||
|
fmt.Println("DeepSeekToolParser.Add: Returning tool calls:", toolCalls)
|
||||||
|
return toolCalls, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DeepSeekToolParser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||||
|
return NewDeepSeekToolParser(templateToProcess)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeepSeekToolParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||||
|
// Create base parser first
|
||||||
|
baseParser, err := NewParser(templateToProcess)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create base parser: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DeepSeekToolParser{
|
||||||
|
parser: baseParser,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDeepSeekToolCalls(s string) ([]api.ToolCall, error) {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Starting with input:", s)
|
||||||
|
fields := strings.Fields(s)
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Split fields:", fields)
|
||||||
|
|
||||||
|
sep := "<|tool▁sep|>"
|
||||||
|
var functionNames []string
|
||||||
|
for _, field := range fields {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Processing field:", field)
|
||||||
|
// TODO: check if brittle
|
||||||
|
if strings.Contains(field, "function") {
|
||||||
|
idx := strings.Index(field, "function")
|
||||||
|
if idx == -1 {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: No 'function' prefix found")
|
||||||
|
return nil, errAccumulateMore
|
||||||
|
}
|
||||||
|
functionName := field[idx+len("function"):]
|
||||||
|
// functionName, cut := strings.CutPrefix(field, "function")
|
||||||
|
// if !cut {
|
||||||
|
// fmt.Println("parseDeepSeekToolCalls: Failed to cut 'function' prefix")
|
||||||
|
// return nil, errAccumulateMore
|
||||||
|
// }
|
||||||
|
// pass through on this is fine as it doesn't always come down
|
||||||
|
functionName, _ = strings.CutPrefix(functionName, sep)
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Found function name:", functionName)
|
||||||
|
functionNames = append(functionNames, functionName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(functionNames) == 0 {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: No function names found")
|
||||||
|
return nil, errAccumulateMore
|
||||||
|
}
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Found function names:", functionNames)
|
||||||
|
|
||||||
|
braceCount := 0
|
||||||
|
startIndex := -1
|
||||||
|
|
||||||
|
var rawToolArgs []string
|
||||||
|
for i, c := range s {
|
||||||
|
switch c {
|
||||||
|
case '{':
|
||||||
|
braceCount++
|
||||||
|
if startIndex == -1 {
|
||||||
|
startIndex = i
|
||||||
|
fmt.Printf("parseDeepSeekToolCalls: Found opening brace at index %d\n", i)
|
||||||
|
}
|
||||||
|
case '}':
|
||||||
|
braceCount--
|
||||||
|
if braceCount == 0 {
|
||||||
|
rawToolArgs = append(rawToolArgs, s[startIndex:i+1])
|
||||||
|
fmt.Printf("parseDeepSeekToolCalls: Found closing brace at index %d, captured: %s\n", i, s[startIndex:i+1])
|
||||||
|
startIndex = -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Raw tool arguments:", rawToolArgs)
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
// unmarshal args
|
||||||
|
var args map[string]any
|
||||||
|
for i, rawToolArg := range rawToolArgs {
|
||||||
|
fmt.Printf("parseDeepSeekToolCalls: Unmarshaling tool arg %d: %s\n", i, rawToolArg)
|
||||||
|
if err := json.Unmarshal([]byte(rawToolArg), &args); err != nil {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Failed to unmarshal JSON:", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: functionNames[i],
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
fmt.Printf("parseDeepSeekToolCalls: Created tool call %d with name %s and args %v\n", i, functionNames[i], args)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) == 0 {
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: No tool calls created")
|
||||||
|
// todo: check err here
|
||||||
|
return nil, errInvalidToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("parseDeepSeekToolCalls: Returning tool calls:", toolCalls)
|
||||||
|
return toolCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ! use as prefix
|
||||||
|
// {{"<|tool▁call▁begin|>
|
||||||
|
// ! send to tc parser
|
||||||
|
// * function<|tool▁sep|><function_name>\n```json\n<function_arguments_in_json_format>\n```<|tool▁call▁end|>"}}
|
||||||
86
tools/deepseek_tools_test.go
Normal file
86
tools/deepseek_tools_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeepSeekToolParser(t *testing.T) {
|
||||||
|
p := filepath.Join("testdata")
|
||||||
|
t1 := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
Index: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// t2 := api.ToolCall{
|
||||||
|
// Function: api.ToolCallFunction{
|
||||||
|
// Name: "get_current_weather",
|
||||||
|
// Arguments: map[string]any{
|
||||||
|
// "format": "celsius",
|
||||||
|
// "location": "Toronto, Canada",
|
||||||
|
// },
|
||||||
|
// Index: 1,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
output string
|
||||||
|
expectedToolCall []api.ToolCall
|
||||||
|
expectedTokens string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single tool call",
|
||||||
|
output: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||||
|
` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// name: "multiple tool calls",
|
||||||
|
// template: `"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"`,
|
||||||
|
// output: `<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||||
|
// ` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>
|
||||||
|
// <|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||||
|
// ` + "```json\n" + `{"format":"celsius","location":"Toronto, Canada"}` + "\n```" + `<|tool▁call▁end|>`,
|
||||||
|
// expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
// expectedTokens: "",
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "invalid tool call format",
|
||||||
|
// template: `{{"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"}}`,
|
||||||
|
// output: "This is just some text without a tool call",
|
||||||
|
// expectedToolCall: nil,
|
||||||
|
// expectedTokens: "This is just some text without a tool call",
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := template.Parse(readFile(t, p, "deepseek-r1.gotmpl").String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Println(tmpl.Template.Root.String())
|
||||||
|
|
||||||
|
parser, err := NewDeepSeekToolParser(tmpl.Template)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tools, content := parser.Add(tt.output)
|
||||||
|
assert.Equal(t, tt.expectedToolCall, tools)
|
||||||
|
assert.Equal(t, tt.expectedTokens, content)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
@@ -16,6 +17,11 @@ var (
|
|||||||
errAccumulateMore = errors.New("need to accumulate more content")
|
errAccumulateMore = errors.New("need to accumulate more content")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ToolParser interface {
|
||||||
|
Add(s string) (tools []api.ToolCall, content string)
|
||||||
|
NewParser(templateToProcess *gotmpl.Template) (ToolParser, error)
|
||||||
|
}
|
||||||
|
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
greedyParseJSON bool
|
greedyParseJSON bool
|
||||||
prefix string
|
prefix string
|
||||||
@@ -104,6 +110,9 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("name", name)
|
||||||
|
fmt.Println("arguments", arguments)
|
||||||
|
fmt.Println("parseJSONToolCalls: Objects:", objs)
|
||||||
// Extract tool calls from objects
|
// Extract tool calls from objects
|
||||||
for _, kv := range objs {
|
for _, kv := range objs {
|
||||||
n, nok := kv[name].(string)
|
n, nok := kv[name].(string)
|
||||||
@@ -123,7 +132,6 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
|
|||||||
|
|
||||||
// Valid JSON, no tool calls found
|
// Valid JSON, no tool calls found
|
||||||
if len(toolCalls) == 0 {
|
if len(toolCalls) == 0 {
|
||||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
|
||||||
return nil, errInvalidToolCall
|
return nil, errInvalidToolCall
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,6 +185,7 @@ func (p *Parser) checkPrefix(s string) (string, error) {
|
|||||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
|
fmt.Println("Add: Starting with input:", s)
|
||||||
|
|
||||||
// Check for prefix pattern in input
|
// Check for prefix pattern in input
|
||||||
s, err := p.checkPrefix(s)
|
s, err := p.checkPrefix(s)
|
||||||
@@ -225,23 +234,37 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
|||||||
//
|
//
|
||||||
// Returns an error if the template does not contain valid tool call formatting.
|
// Returns an error if the template does not contain valid tool call formatting.
|
||||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||||
|
fmt.Println("Checkpoint 1: Starting NewParser")
|
||||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println("Checkpoint 2: Error parsing template:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("Checkpoint 3: Getting tool template")
|
||||||
tt, err := toolTemplate(parsed)
|
tt, err := toolTemplate(parsed)
|
||||||
|
fmt.Println("Checkpoint 4: Tool template:", tt.Root.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println("Checkpoint 5: Error getting tool template:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("Checkpoint 6: Getting tool prefix")
|
||||||
tp := toolPrefix(templateToProcess)
|
tp := toolPrefix(templateToProcess)
|
||||||
|
fmt.Println("Checkpoint 7: Tool prefix:", tp)
|
||||||
|
|
||||||
|
fmt.Println("Checkpoint 8: Extracting tool args")
|
||||||
name, arguments, err := extractToolArgs(tt)
|
name, arguments, err := extractToolArgs(tt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Println("Checkpoint 9: Error extracting tool args:", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// name := "temp1"
|
||||||
|
// args := "temp2"
|
||||||
|
|
||||||
|
fmt.Println("Checkpoint 10: Tool name:", name, "arguments:", arguments)
|
||||||
|
|
||||||
|
fmt.Println("Checkpoint 11: Creating parser")
|
||||||
return &Parser{
|
return &Parser{
|
||||||
tmpl: *tt,
|
tmpl: *tt,
|
||||||
sb: strings.Builder{},
|
sb: strings.Builder{},
|
||||||
@@ -251,3 +274,8 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
|||||||
arguments: arguments,
|
arguments: arguments,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewParser implements the ToolParser interface
|
||||||
|
func (p *Parser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||||
|
return NewParser(templateToProcess)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -98,6 +99,7 @@ func isToolCallsNode(n *parse.IfNode) bool {
|
|||||||
|
|
||||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||||
|
fmt.Println("tokenText", tokenText)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user