mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 04:29:51 -05:00
Compare commits
1 Commits
implement-
...
parth/deep
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
478824045d |
@@ -1523,8 +1523,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
var toolParser tools.ToolParser
|
||||
|
||||
fmt.Println("m.Config.ModelFamily", m.Config.ModelFamily)
|
||||
if m.Config.ModelFamily == "qwen" {
|
||||
slog.Info("using deepseek tool parser")
|
||||
fmt.Println("m.Template.Template", m.Template.Template)
|
||||
toolParser, err = tools.NewDeepSeekToolParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else if len(req.Tools) > 0 {
|
||||
slog.Info("using default tool parser")
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
|
||||
179
tools/deepseek_tools.go
Normal file
179
tools/deepseek_tools.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type DeepSeekToolParser struct {
|
||||
parser *Parser // Embed the base parser as a field
|
||||
}
|
||||
|
||||
func (p *DeepSeekToolParser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
fmt.Println("prefix", p.parser.prefix)
|
||||
fmt.Println("DeepSeekToolParser.Add: Starting with input:", s)
|
||||
p.parser.sb.WriteString(s)
|
||||
s = p.parser.sb.String()
|
||||
fmt.Println("DeepSeekToolParser.Add: After StringBuilder:", s)
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.parser.checkPrefix(s)
|
||||
fmt.Println("DeepSeekToolParser.Add: After checkPrefix:", s, "error:", err)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.parser.prefixFound {
|
||||
fmt.Println("DeepSeekToolParser.Add: Prefix not found, resetting")
|
||||
p.parser.sb.Reset()
|
||||
return nil, s
|
||||
}
|
||||
|
||||
toolCalls, err := parseDeepSeekToolCalls(s)
|
||||
fmt.Println("DeepSeekToolParser.Add: After parseDeepSeekToolCalls:", toolCalls, "error:", err)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
}
|
||||
p.parser.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.parser.prefix != "" {
|
||||
fmt.Println("DeepSeekToolParser.Add: Disabling greedy parsing")
|
||||
p.parser.greedyParseJSON = false
|
||||
}
|
||||
if p.parser.index != 0 && p.parser.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.parser.prefixFound {
|
||||
fmt.Println("DeepSeekToolParser.Add: Prefix found but invalid tool call")
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
}
|
||||
|
||||
fmt.Println("DeepSeekToolParser.Add: Processing tool calls")
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.parser.index
|
||||
p.parser.index++
|
||||
}
|
||||
|
||||
p.parser.sb.Reset()
|
||||
fmt.Println("DeepSeekToolParser.Add: Returning tool calls:", toolCalls)
|
||||
return toolCalls, ""
|
||||
}
|
||||
|
||||
func (p *DeepSeekToolParser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||
return NewDeepSeekToolParser(templateToProcess)
|
||||
}
|
||||
|
||||
func NewDeepSeekToolParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||
// Create base parser first
|
||||
baseParser, err := NewParser(templateToProcess)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create base parser: %w", err)
|
||||
}
|
||||
|
||||
return &DeepSeekToolParser{
|
||||
parser: baseParser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseDeepSeekToolCalls(s string) ([]api.ToolCall, error) {
|
||||
fmt.Println("parseDeepSeekToolCalls: Starting with input:", s)
|
||||
fields := strings.Fields(s)
|
||||
fmt.Println("parseDeepSeekToolCalls: Split fields:", fields)
|
||||
|
||||
sep := "<|tool▁sep|>"
|
||||
var functionNames []string
|
||||
for _, field := range fields {
|
||||
fmt.Println("parseDeepSeekToolCalls: Processing field:", field)
|
||||
// TODO: check if brittle
|
||||
if strings.Contains(field, "function") {
|
||||
idx := strings.Index(field, "function")
|
||||
if idx == -1 {
|
||||
fmt.Println("parseDeepSeekToolCalls: No 'function' prefix found")
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
functionName := field[idx+len("function"):]
|
||||
// functionName, cut := strings.CutPrefix(field, "function")
|
||||
// if !cut {
|
||||
// fmt.Println("parseDeepSeekToolCalls: Failed to cut 'function' prefix")
|
||||
// return nil, errAccumulateMore
|
||||
// }
|
||||
// pass through on this is fine as it doesn't always come down
|
||||
functionName, _ = strings.CutPrefix(functionName, sep)
|
||||
fmt.Println("parseDeepSeekToolCalls: Found function name:", functionName)
|
||||
functionNames = append(functionNames, functionName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(functionNames) == 0 {
|
||||
fmt.Println("parseDeepSeekToolCalls: No function names found")
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
fmt.Println("parseDeepSeekToolCalls: Found function names:", functionNames)
|
||||
|
||||
braceCount := 0
|
||||
startIndex := -1
|
||||
|
||||
var rawToolArgs []string
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
fmt.Printf("parseDeepSeekToolCalls: Found opening brace at index %d\n", i)
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolArgs = append(rawToolArgs, s[startIndex:i+1])
|
||||
fmt.Printf("parseDeepSeekToolCalls: Found closing brace at index %d, captured: %s\n", i, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("parseDeepSeekToolCalls: Raw tool arguments:", rawToolArgs)
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
// unmarshal args
|
||||
var args map[string]any
|
||||
for i, rawToolArg := range rawToolArgs {
|
||||
fmt.Printf("parseDeepSeekToolCalls: Unmarshaling tool arg %d: %s\n", i, rawToolArg)
|
||||
if err := json.Unmarshal([]byte(rawToolArg), &args); err != nil {
|
||||
fmt.Println("parseDeepSeekToolCalls: Failed to unmarshal JSON:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: functionNames[i],
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
fmt.Printf("parseDeepSeekToolCalls: Created tool call %d with name %s and args %v\n", i, functionNames[i], args)
|
||||
}
|
||||
|
||||
if len(toolCalls) == 0 {
|
||||
fmt.Println("parseDeepSeekToolCalls: No tool calls created")
|
||||
// todo: check err here
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
fmt.Println("parseDeepSeekToolCalls: Returning tool calls:", toolCalls)
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// ! use as prefix
|
||||
// {{"<|tool▁call▁begin|>
|
||||
// ! send to tc parser
|
||||
// * function<|tool▁sep|><function_name>\n```json\n<function_arguments_in_json_format>\n```<|tool▁call▁end|>"}}
|
||||
86
tools/deepseek_tools_test.go
Normal file
86
tools/deepseek_tools_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeepSeekToolParser(t *testing.T) {
|
||||
p := filepath.Join("testdata")
|
||||
t1 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
Index: 0,
|
||||
},
|
||||
}
|
||||
|
||||
// t2 := api.ToolCall{
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "get_current_weather",
|
||||
// Arguments: map[string]any{
|
||||
// "format": "celsius",
|
||||
// "location": "Toronto, Canada",
|
||||
// },
|
||||
// Index: 1,
|
||||
// },
|
||||
// }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
output string
|
||||
expectedToolCall []api.ToolCall
|
||||
expectedTokens string
|
||||
}{
|
||||
{
|
||||
name: "single tool call",
|
||||
output: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||
` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
// {
|
||||
// name: "multiple tool calls",
|
||||
// template: `"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"`,
|
||||
// output: `<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||
// ` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>
|
||||
// <|tool▁call▁begin|>function<|tool▁sep|>get_current_weather
|
||||
// ` + "```json\n" + `{"format":"celsius","location":"Toronto, Canada"}` + "\n```" + `<|tool▁call▁end|>`,
|
||||
// expectedToolCall: []api.ToolCall{t1, t2},
|
||||
// expectedTokens: "",
|
||||
// },
|
||||
// {
|
||||
// name: "invalid tool call format",
|
||||
// template: `{{"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"}}`,
|
||||
// output: "This is just some text without a tool call",
|
||||
// expectedToolCall: nil,
|
||||
// expectedTokens: "This is just some text without a tool call",
|
||||
// },
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(readFile(t, p, "deepseek-r1.gotmpl").String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fmt.Println(tmpl.Template.Root.String())
|
||||
|
||||
parser, err := NewDeepSeekToolParser(tmpl.Template)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tools, content := parser.Add(tt.output)
|
||||
assert.Equal(t, tt.expectedToolCall, tools)
|
||||
assert.Equal(t, tt.expectedTokens, content)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package tools
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
@@ -16,6 +17,11 @@ var (
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
)
|
||||
|
||||
type ToolParser interface {
|
||||
Add(s string) (tools []api.ToolCall, content string)
|
||||
NewParser(templateToProcess *gotmpl.Template) (ToolParser, error)
|
||||
}
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
@@ -104,6 +110,9 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Println("name", name)
|
||||
fmt.Println("arguments", arguments)
|
||||
fmt.Println("parseJSONToolCalls: Objects:", objs)
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
@@ -123,7 +132,6 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
@@ -177,6 +185,7 @@ func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
fmt.Println("Add: Starting with input:", s)
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
@@ -225,23 +234,37 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
fmt.Println("Checkpoint 1: Starting NewParser")
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
fmt.Println("Checkpoint 2: Error parsing template:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("Checkpoint 3: Getting tool template")
|
||||
tt, err := toolTemplate(parsed)
|
||||
fmt.Println("Checkpoint 4: Tool template:", tt.Root.String())
|
||||
if err != nil {
|
||||
fmt.Println("Checkpoint 5: Error getting tool template:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("Checkpoint 6: Getting tool prefix")
|
||||
tp := toolPrefix(templateToProcess)
|
||||
fmt.Println("Checkpoint 7: Tool prefix:", tp)
|
||||
|
||||
fmt.Println("Checkpoint 8: Extracting tool args")
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
fmt.Println("Checkpoint 9: Error extracting tool args:", err)
|
||||
return nil, err
|
||||
}
|
||||
// name := "temp1"
|
||||
// args := "temp2"
|
||||
|
||||
fmt.Println("Checkpoint 10: Tool name:", name, "arguments:", arguments)
|
||||
|
||||
fmt.Println("Checkpoint 11: Creating parser")
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
@@ -251,3 +274,8 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewParser implements the ToolParser interface
|
||||
func (p *Parser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) {
|
||||
return NewParser(templateToProcess)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -98,6 +99,7 @@ func isToolCallsNode(n *parse.IfNode) bool {
|
||||
|
||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||
fmt.Println("tokenText", tokenText)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user