Compare commits

...

1 Commits

Author SHA1 Message Date
ParthSareen
478824045d temp 2025-05-29 16:17:11 -07:00
5 changed files with 310 additions and 3 deletions

View File

@@ -1523,8 +1523,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
var toolParser *tools.Parser
if len(req.Tools) > 0 {
var toolParser tools.ToolParser
fmt.Println("m.Config.ModelFamily", m.Config.ModelFamily)
if m.Config.ModelFamily == "qwen" {
slog.Info("using deepseek tool parser")
fmt.Println("m.Template.Template", m.Template.Template)
toolParser, err = tools.NewDeepSeekToolParser(m.Template.Template)
if err != nil {
slog.Error("failed to create tool parser", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else if len(req.Tools) > 0 {
slog.Info("using default tool parser")
toolParser, err = tools.NewParser(m.Template.Template)
if err != nil {
slog.Error("failed to create tool parser", "error", err)

179
tools/deepseek_tools.go Normal file
View File

@@ -0,0 +1,179 @@
package tools
import (
"encoding/json"
"errors"
"fmt"
"strings"
gotmpl "text/template"
"github.com/ollama/ollama/api"
)
type DeepSeekToolParser struct {
parser *Parser // Embed the base parser as a field
}
func (p *DeepSeekToolParser) Add(s string) (tools []api.ToolCall, content string) {
fmt.Println("prefix", p.parser.prefix)
fmt.Println("DeepSeekToolParser.Add: Starting with input:", s)
p.parser.sb.WriteString(s)
s = p.parser.sb.String()
fmt.Println("DeepSeekToolParser.Add: After StringBuilder:", s)
// Check for prefix pattern in input
s, err := p.parser.checkPrefix(s)
fmt.Println("DeepSeekToolParser.Add: After checkPrefix:", s, "error:", err)
if err != nil {
// Need more input to complete prefix
return nil, s
}
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
if !p.parser.prefixFound {
fmt.Println("DeepSeekToolParser.Add: Prefix not found, resetting")
p.parser.sb.Reset()
return nil, s
}
toolCalls, err := parseDeepSeekToolCalls(s)
fmt.Println("DeepSeekToolParser.Add: After parseDeepSeekToolCalls:", toolCalls, "error:", err)
if err != nil {
if errors.Is(err, errAccumulateMore) {
return nil, ""
}
p.parser.sb.Reset()
// Only do greedy JSON parsing if there is no prefix from template
if p.parser.prefix != "" {
fmt.Println("DeepSeekToolParser.Add: Disabling greedy parsing")
p.parser.greedyParseJSON = false
}
if p.parser.index != 0 && p.parser.prefix == "" {
return nil, ""
}
if p.parser.prefixFound {
fmt.Println("DeepSeekToolParser.Add: Prefix found but invalid tool call")
// Drop tokens since prefix was found
return nil, ""
}
return nil, s
}
fmt.Println("DeepSeekToolParser.Add: Processing tool calls")
for _, tc := range toolCalls {
tc.Function.Index = p.parser.index
p.parser.index++
}
p.parser.sb.Reset()
fmt.Println("DeepSeekToolParser.Add: Returning tool calls:", toolCalls)
return toolCalls, ""
}
func (p *DeepSeekToolParser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
return NewDeepSeekToolParser(templateToProcess)
}
func NewDeepSeekToolParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
// Create base parser first
baseParser, err := NewParser(templateToProcess)
if err != nil {
return nil, fmt.Errorf("failed to create base parser: %w", err)
}
return &DeepSeekToolParser{
parser: baseParser,
}, nil
}
func parseDeepSeekToolCalls(s string) ([]api.ToolCall, error) {
fmt.Println("parseDeepSeekToolCalls: Starting with input:", s)
fields := strings.Fields(s)
fmt.Println("parseDeepSeekToolCalls: Split fields:", fields)
sep := "<tool▁sep>"
var functionNames []string
for _, field := range fields {
fmt.Println("parseDeepSeekToolCalls: Processing field:", field)
// TODO: check if brittle
if strings.Contains(field, "function") {
idx := strings.Index(field, "function")
if idx == -1 {
fmt.Println("parseDeepSeekToolCalls: No 'function' prefix found")
return nil, errAccumulateMore
}
functionName := field[idx+len("function"):]
// functionName, cut := strings.CutPrefix(field, "function")
// if !cut {
// fmt.Println("parseDeepSeekToolCalls: Failed to cut 'function' prefix")
// return nil, errAccumulateMore
// }
// pass through on this is fine as it doesn't always come down
functionName, _ = strings.CutPrefix(functionName, sep)
fmt.Println("parseDeepSeekToolCalls: Found function name:", functionName)
functionNames = append(functionNames, functionName)
}
}
if len(functionNames) == 0 {
fmt.Println("parseDeepSeekToolCalls: No function names found")
return nil, errAccumulateMore
}
fmt.Println("parseDeepSeekToolCalls: Found function names:", functionNames)
braceCount := 0
startIndex := -1
var rawToolArgs []string
for i, c := range s {
switch c {
case '{':
braceCount++
if startIndex == -1 {
startIndex = i
fmt.Printf("parseDeepSeekToolCalls: Found opening brace at index %d\n", i)
}
case '}':
braceCount--
if braceCount == 0 {
rawToolArgs = append(rawToolArgs, s[startIndex:i+1])
fmt.Printf("parseDeepSeekToolCalls: Found closing brace at index %d, captured: %s\n", i, s[startIndex:i+1])
startIndex = -1
}
}
}
fmt.Println("parseDeepSeekToolCalls: Raw tool arguments:", rawToolArgs)
var toolCalls []api.ToolCall
// unmarshal args
var args map[string]any
for i, rawToolArg := range rawToolArgs {
fmt.Printf("parseDeepSeekToolCalls: Unmarshaling tool arg %d: %s\n", i, rawToolArg)
if err := json.Unmarshal([]byte(rawToolArg), &args); err != nil {
fmt.Println("parseDeepSeekToolCalls: Failed to unmarshal JSON:", err)
return nil, err
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: functionNames[i],
Arguments: args,
},
})
fmt.Printf("parseDeepSeekToolCalls: Created tool call %d with name %s and args %v\n", i, functionNames[i], args)
}
if len(toolCalls) == 0 {
fmt.Println("parseDeepSeekToolCalls: No tool calls created")
// todo: check err here
return nil, errInvalidToolCall
}
fmt.Println("parseDeepSeekToolCalls: Returning tool calls:", toolCalls)
return toolCalls, nil
}
// ! use as prefix
// {{"<tool▁call▁begin>
// ! send to tc parser
// * function<tool▁sep><function_name>\n```json\n<function_arguments_in_json_format>\n```<tool▁call▁end>"}}

View File

@@ -0,0 +1,86 @@
package tools
import (
"fmt"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
"github.com/stretchr/testify/assert"
)
func TestDeepSeekToolParser(t *testing.T) {
p := filepath.Join("testdata")
t1 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
Index: 0,
},
}
// t2 := api.ToolCall{
// Function: api.ToolCallFunction{
// Name: "get_current_weather",
// Arguments: map[string]any{
// "format": "celsius",
// "location": "Toronto, Canada",
// },
// Index: 1,
// },
// }
tests := []struct {
name string
template string
output string
expectedToolCall []api.ToolCall
expectedTokens string
}{
{
name: "single tool call",
output: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_current_weather
` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
// {
// name: "multiple tool calls",
// template: `"<tool▁call▁begin>function<tool▁sep>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>"`,
// output: `<tool▁call▁begin>function<tool▁sep>get_current_weather
// ` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>
// <tool▁call▁begin>function<tool▁sep>get_current_weather
// ` + "```json\n" + `{"format":"celsius","location":"Toronto, Canada"}` + "\n```" + `<tool▁call▁end>`,
// expectedToolCall: []api.ToolCall{t1, t2},
// expectedTokens: "",
// },
// {
// name: "invalid tool call format",
// template: `{{"<tool▁call▁begin>function<tool▁sep>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>"}}`,
// output: "This is just some text without a tool call",
// expectedToolCall: nil,
// expectedTokens: "This is just some text without a tool call",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, "deepseek-r1.gotmpl").String())
if err != nil {
t.Fatal(err)
}
fmt.Println(tmpl.Template.Root.String())
parser, err := NewDeepSeekToolParser(tmpl.Template)
assert.NoError(t, err)
tools, content := parser.Add(tt.output)
assert.Equal(t, tt.expectedToolCall, tools)
assert.Equal(t, tt.expectedTokens, content)
})
}
}

View File

@@ -3,6 +3,7 @@ package tools
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
gotmpl "text/template"
@@ -16,6 +17,11 @@ var (
errAccumulateMore = errors.New("need to accumulate more content")
)
type ToolParser interface {
Add(s string) (tools []api.ToolCall, content string)
NewParser(templateToProcess *gotmpl.Template) (ToolParser, error)
}
type Parser struct {
greedyParseJSON bool
prefix string
@@ -104,6 +110,9 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
continue
}
fmt.Println("name", name)
fmt.Println("arguments", arguments)
fmt.Println("parseJSONToolCalls: Objects:", objs)
// Extract tool calls from objects
for _, kv := range objs {
n, nok := kv[name].(string)
@@ -123,7 +132,6 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
// Valid JSON, no tool calls found
if len(toolCalls) == 0 {
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
return nil, errInvalidToolCall
}
@@ -177,6 +185,7 @@ func (p *Parser) checkPrefix(s string) (string, error) {
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
p.sb.WriteString(s)
s = p.sb.String()
fmt.Println("Add: Starting with input:", s)
// Check for prefix pattern in input
s, err := p.checkPrefix(s)
@@ -225,23 +234,37 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
//
// Returns an error if the template does not contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
fmt.Println("Checkpoint 1: Starting NewParser")
parsed, err := template.Parse(templateToProcess.Root.String())
if err != nil {
fmt.Println("Checkpoint 2: Error parsing template:", err)
return nil, err
}
fmt.Println("Checkpoint 3: Getting tool template")
tt, err := toolTemplate(parsed)
fmt.Println("Checkpoint 4: Tool template:", tt.Root.String())
if err != nil {
fmt.Println("Checkpoint 5: Error getting tool template:", err)
return nil, err
}
fmt.Println("Checkpoint 6: Getting tool prefix")
tp := toolPrefix(templateToProcess)
fmt.Println("Checkpoint 7: Tool prefix:", tp)
fmt.Println("Checkpoint 8: Extracting tool args")
name, arguments, err := extractToolArgs(tt)
if err != nil {
fmt.Println("Checkpoint 9: Error extracting tool args:", err)
return nil, err
}
// name := "temp1"
// args := "temp2"
fmt.Println("Checkpoint 10: Tool name:", name, "arguments:", arguments)
fmt.Println("Checkpoint 11: Creating parser")
return &Parser{
tmpl: *tt,
sb: strings.Builder{},
@@ -251,3 +274,8 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
arguments: arguments,
}, nil
}
// NewParser implements the ToolParser interface
func (p *Parser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
return NewParser(templateToProcess)
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
@@ -98,6 +99,7 @@ func isToolCallsNode(n *parse.IfNode) bool {
func toolPrefix(tmpl *gotmpl.Template) string {
tokenText, ok := extractToolCallsFormat(tmpl)
fmt.Println("tokenText", tokenText)
if !ok {
return ""
}