mirror of
https://github.com/ollama/ollama.git
synced 2026-02-28 04:56:37 -05:00
Compare commits
7 Commits
pdevine/sa
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 |
14
api/types.go
14
api/types.go
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/internal/orderedmap"
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
@@ -569,6 +570,7 @@ type DebugInfo struct {
|
|||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
|
|||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.PeakMemory > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
||||||
|
}
|
||||||
|
|
||||||
if m.LoadDuration > 0 {
|
if m.LoadDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||||
}
|
}
|
||||||
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatPeakMemory(b uint64) string {
|
||||||
|
if b >= format.GibiByte {
|
||||||
|
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
||||||
|
}
|
||||||
|
|
||||||
|
return format.HumanBytes2(b)
|
||||||
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]any) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
|||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
MemorySize() (total, vram uint64)
|
||||||
TotalSize() uint64
|
|
||||||
VRAMByGPU(id ml.DeviceID) uint64
|
VRAMByGPU(id ml.DeviceID) uint64
|
||||||
Pid() int
|
Pid() int
|
||||||
GetPort() int
|
GetPort() int
|
||||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
totalSize, _ := s.MemorySize()
|
||||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||||
@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||||
EvalCount int `json:"eval_count"`
|
EvalCount int `json:"eval_count"`
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
|
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||||
|
|
||||||
// Logprobs contains log probability information if requested
|
// Logprobs contains log probability information if requested
|
||||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||||
@@ -1848,17 +1849,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMSize() uint64 {
|
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||||
if s.mem == nil {
|
if s.mem == nil {
|
||||||
return 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var mem uint64
|
|
||||||
|
|
||||||
for _, g := range s.mem.GPUs {
|
for _, g := range s.mem.GPUs {
|
||||||
mem += g.Size()
|
vram += g.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||||
|
|
||||||
// Some elements are always on CPU. However, if we have allocated all layers
|
// Some elements are always on CPU. However, if we have allocated all layers
|
||||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||||
noCPULayers := true
|
noCPULayers := true
|
||||||
@@ -1869,25 +1870,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if noCPULayers {
|
if noCPULayers {
|
||||||
mem += s.mem.InputWeights
|
vram += s.mem.InputWeights
|
||||||
mem += s.mem.CPU.Graph
|
vram += s.mem.CPU.Graph
|
||||||
}
|
}
|
||||||
|
|
||||||
return mem
|
return total, vram
|
||||||
}
|
|
||||||
|
|
||||||
func (s *llmServer) TotalSize() uint64 {
|
|
||||||
if s.mem == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
mem := s.mem.InputWeights
|
|
||||||
mem += s.mem.CPU.Size()
|
|
||||||
for _, g := range s.mem.GPUs {
|
|
||||||
mem += g.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return mem
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type GLM46Parser struct {
|
type GLM46Parser struct {
|
||||||
state glm46ParserState
|
state glm46ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
|||||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case glm46EventThinkingContent:
|
case glm46EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
|||||||
|
|
||||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||||
// so model output starts directly with thinking content (no opening tag).
|
// so model output starts directly with thinking content (no opening tag).
|
||||||
if thinkValue == nil || thinkValue.Bool() {
|
if thinkValue == nil || thinkValue.Bool() {
|
||||||
|
|||||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
|||||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `plan</think>
|
||||||
|
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||||
|
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||||
|
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := GLM47Parser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
|||||||
state qwen3ParserState
|
state qwen3ParserState
|
||||||
buffer strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
hasThinkingSupport bool
|
hasThinkingSupport bool
|
||||||
defaultThinking bool
|
defaultThinking bool
|
||||||
maybeThinkingOpenAtBOL bool
|
maybeThinkingOpenAtBOL bool
|
||||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
|||||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
p.buffer.Reset()
|
p.buffer.Reset()
|
||||||
|
p.callIndex = 0
|
||||||
|
|
||||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
if thinkValue == nil {
|
if thinkValue == nil {
|
||||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
|||||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
calls = append(calls, toolCall)
|
calls = append(calls, toolCall)
|
||||||
case qwen3EventThinkingContent:
|
case qwen3EventThinkingContent:
|
||||||
thinkingSb.WriteString(event.content)
|
thinkingSb.WriteString(event.content)
|
||||||
|
|||||||
@@ -230,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
|||||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||||
|
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||||
|
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -29,9 +29,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Qwen3CoderParser struct {
|
type Qwen3CoderParser struct {
|
||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
|
callIndex int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
|
|
||||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
p.callIndex = 0
|
||||||
return tools // Qwen doesn't modify tools
|
return tools // Qwen doesn't modify tools
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
|||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
|
toolCall.Function.Index = p.callIndex
|
||||||
|
p.callIndex++
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
case qwenEventContent:
|
case qwenEventContent:
|
||||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
|||||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||||
|
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||||
|
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||||
|
_, _, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(calls) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(calls[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var all []api.ToolCall
|
||||||
|
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 1 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("step 2 parse failed: %v", err)
|
||||||
|
}
|
||||||
|
all = append(all, calls...)
|
||||||
|
|
||||||
|
want := []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||||
|
}
|
||||||
|
if len(all) != len(want) {
|
||||||
|
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if !toolCallEqual(all[i], want[i]) {
|
||||||
|
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if !toolCallEqual(calls[0], want) {
|
||||||
|
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenXMLTransform(t *testing.T) {
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) IsMLX() bool {
|
||||||
|
return m.Config.ModelFormat == "safetensors"
|
||||||
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
// Capabilities returns the capabilities that the model supports
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
func (m *Model) Capabilities() []model.Capability {
|
||||||
capabilities := []model.Capability{}
|
capabilities := []model.Capability{}
|
||||||
|
|||||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
lastMsgIdx := len(msgs) - 1
|
lastMsgIdx := len(msgs) - 1
|
||||||
currMsgIdx := 0
|
currMsgIdx := 0
|
||||||
|
|
||||||
// Start with all messages and remove from the front until it fits in context
|
if truncate {
|
||||||
for i := 0; i <= lastMsgIdx; i++ {
|
// Start with all messages and remove from the front until it fits in context
|
||||||
// Collect system messages from the portion we're about to skip
|
for i := 0; i <= lastMsgIdx; i++ {
|
||||||
system = make([]api.Message, 0)
|
// Collect system messages from the portion we're about to skip
|
||||||
for j := range i {
|
system = make([]api.Message, 0)
|
||||||
if msgs[j].Role == "system" {
|
for j := range i {
|
||||||
system = append(system, msgs[j])
|
if msgs[j].Role == "system" {
|
||||||
|
system = append(system, msgs[j])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
|
||||||
|
|
||||||
s, err := tokenize(ctx, p)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctxLen := len(s)
|
|
||||||
if m.ProjectorPaths != nil {
|
|
||||||
for _, msg := range msgs[i:] {
|
|
||||||
ctxLen += imageNumTokens * len(msg.Images)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !truncate || ctxLen <= opts.NumCtx {
|
s, err := tokenize(ctx, p)
|
||||||
currMsgIdx = i
|
if err != nil {
|
||||||
break
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Must always include at least the last message
|
ctxLen := len(s)
|
||||||
if i == lastMsgIdx {
|
if m.ProjectorPaths != nil {
|
||||||
currMsgIdx = lastMsgIdx
|
for _, msg := range msgs[i:] {
|
||||||
break
|
ctxLen += imageNumTokens * len(msg.Images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctxLen <= opts.NumCtx {
|
||||||
|
currMsgIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must always include at least the last message
|
||||||
|
if i == lastMsgIdx {
|
||||||
|
currMsgIdx = lastMsgIdx
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
// the real chat handler, but doing this as a stopgap to get renderer
|
// the real chat handler, but doing this as a stopgap to get renderer
|
||||||
// support for generate
|
// support for generate
|
||||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||||
|
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -557,6 +558,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: cr.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
|
PeakMemory: cr.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||||
}
|
}
|
||||||
@@ -1951,6 +1953,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if v.llama != nil {
|
if v.llama != nil {
|
||||||
mr.ContextLength = v.llama.ContextLength()
|
mr.ContextLength = v.llama.ContextLength()
|
||||||
|
total, vram := v.llama.MemorySize()
|
||||||
|
mr.Size = int64(total)
|
||||||
|
mr.SizeVRAM = int64(vram)
|
||||||
}
|
}
|
||||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||||
// possible that it will be set to the unix epoch. For those cases, just
|
// possible that it will be set to the unix epoch. For those cases, just
|
||||||
@@ -2213,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := req.Truncate == nil || *req.Truncate
|
truncate := req.Truncate == nil || *req.Truncate
|
||||||
|
if m.IsMLX() {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
@@ -2309,6 +2317,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: r.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
|
PeakMemory: r.PeakMemory,
|
||||||
},
|
},
|
||||||
Logprobs: toAPILogprobs(r.Logprobs),
|
Logprobs: toAPILogprobs(r.Logprobs),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for experimental safetensors LLM models
|
// Check for experimental safetensors LLM models
|
||||||
if pending.model.Config.ModelFormat == "safetensors" {
|
if pending.model.IsMLX() {
|
||||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||||
// LLM model with safetensors format - use MLX runner
|
// LLM model with safetensors format - use MLX runner
|
||||||
if s.loadMLX(pending) {
|
if s.loadMLX(pending) {
|
||||||
@@ -536,6 +536,7 @@ iGPUScan:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := llama.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -545,8 +546,8 @@ iGPUScan:
|
|||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpuIDs,
|
gpus: gpuIDs,
|
||||||
discreteGPUs: discreteGPUs,
|
discreteGPUs: discreteGPUs,
|
||||||
vramSize: llama.VRAMSize(),
|
totalSize: totalSize,
|
||||||
totalSize: llama.TotalSize(),
|
vramSize: vramSize,
|
||||||
loading: true,
|
loading: true,
|
||||||
pid: llama.Pid(),
|
pid: llama.Pid(),
|
||||||
}
|
}
|
||||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
sessionDuration = req.sessionDuration.Duration
|
sessionDuration = req.sessionDuration.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalSize, vramSize := server.MemorySize()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
loading: false,
|
loading: false,
|
||||||
isImagegen: isImagegen,
|
isImagegen: isImagegen,
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
totalSize: server.TotalSize(),
|
totalSize: totalSize,
|
||||||
vramSize: server.VRAMSize(),
|
vramSize: vramSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||||
runner.llama.Ping(ctx) != nil {
|
runner.llama.Ping(ctx) != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
|||||||
s.closeCalled = true
|
s.closeCalled = true
|
||||||
return s.closeResp
|
return s.closeResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
|
||||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||||
func (s *mockLlm) Pid() int { return -1 }
|
func (s *mockLlm) Pid() int { return -1 }
|
||||||
func (s *mockLlm) GetPort() int { return -1 }
|
func (s *mockLlm) GetPort() int { return -1 }
|
||||||
|
|||||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMSize returns the estimated VRAM usage.
|
// MemorySize returns the total and VRAM memory usage.
|
||||||
func (s *Server) VRAMSize() uint64 {
|
func (s *Server) MemorySize() (total, vram uint64) {
|
||||||
return s.vramSize
|
return s.vramSize, s.vramSize
|
||||||
}
|
|
||||||
|
|
||||||
// TotalSize returns the total memory usage.
|
|
||||||
func (s *Server) TotalSize() uint64 {
|
|
||||||
return s.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
|||||||
prefix++
|
prefix++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always keep at least one token to re-evaluate so the
|
||||||
|
// pipeline can seed token generation from it.
|
||||||
|
if prefix == len(tokens) && prefix > 0 {
|
||||||
|
prefix--
|
||||||
|
}
|
||||||
|
|
||||||
if prefix < len(c.tokens) {
|
if prefix < len(c.tokens) {
|
||||||
trim := len(c.tokens) - prefix
|
trim := len(c.tokens) - prefix
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -19,25 +18,27 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/x/imagegen"
|
"github.com/ollama/ollama/x/imagegen"
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
port int
|
port int
|
||||||
modelName string
|
modelName string
|
||||||
vramSize uint64
|
contextLength atomic.Int64
|
||||||
done chan error
|
memory atomic.Uint64
|
||||||
client *http.Client
|
done chan error
|
||||||
lastErr string
|
client *http.Client
|
||||||
lastErrLock sync.Mutex
|
lastErr string
|
||||||
mu sync.Mutex
|
lastErrLock sync.Mutex
|
||||||
cmd *exec.Cmd
|
mu sync.Mutex
|
||||||
|
cmd *exec.Cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate VRAM based on tensor size from manifest
|
|
||||||
var vramSize uint64
|
|
||||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
|
||||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
|
||||||
} else {
|
|
||||||
vramSize = 8 * 1024 * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
port: port,
|
port: port,
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
vramSize: vramSize,
|
|
||||||
done: make(chan error, 1),
|
done: make(chan error, 1),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
cmd: cmd,
|
cmd: cmd,
|
||||||
@@ -201,6 +193,20 @@ type completionOpts struct {
|
|||||||
NumPredict int `json:"num_predict,omitempty"`
|
NumPredict int `json:"num_predict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
Content string
|
||||||
|
Done bool
|
||||||
|
DoneReason int
|
||||||
|
|
||||||
|
PromptEvalCount int
|
||||||
|
PromptEvalDuration time.Duration
|
||||||
|
EvalCount int
|
||||||
|
EvalDuration time.Duration
|
||||||
|
PeakMemory uint64
|
||||||
|
|
||||||
|
Error *api.StatusError
|
||||||
|
}
|
||||||
|
|
||||||
// Close terminates the subprocess.
|
// Close terminates the subprocess.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -260,28 +266,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var raw struct {
|
var raw CompletionResponse
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
Done bool `json:"done"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
|
||||||
EvalDuration int `json:"eval_duration,omitempty"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if raw.Error != nil {
|
||||||
|
return *raw.Error
|
||||||
|
}
|
||||||
|
|
||||||
cresp := llm.CompletionResponse{
|
cresp := llm.CompletionResponse{
|
||||||
Content: raw.Content,
|
Content: raw.Content,
|
||||||
Done: raw.Done,
|
Done: raw.Done,
|
||||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||||
PromptEvalCount: raw.PromptEvalCount,
|
PromptEvalCount: raw.PromptEvalCount,
|
||||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: time.Duration(raw.EvalDuration),
|
EvalDuration: raw.EvalDuration,
|
||||||
|
PeakMemory: raw.PeakMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
@@ -294,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
return math.MaxInt
|
return int(c.contextLength.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detokenize implements llm.LlamaServer.
|
// Detokenize implements llm.LlamaServer.
|
||||||
@@ -347,9 +350,16 @@ func (c *Client) Pid() int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type statusResponse struct {
|
||||||
|
Status int
|
||||||
|
Progress int
|
||||||
|
ContextLength int
|
||||||
|
Memory uint64
|
||||||
|
}
|
||||||
|
|
||||||
// Ping implements llm.LlamaServer.
|
// Ping implements llm.LlamaServer.
|
||||||
func (c *Client) Ping(ctx context.Context) error {
|
func (c *Client) Ping(ctx context.Context) error {
|
||||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -362,6 +372,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var status statusResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.contextLength.Store(int64(status.ContextLength))
|
||||||
|
c.memory.Store(status.Memory)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,19 +407,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
|||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TotalSize implements llm.LlamaServer.
|
func (c *Client) currentMemory() uint64 {
|
||||||
func (c *Client) TotalSize() uint64 {
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
return c.vramSize
|
defer cancel()
|
||||||
|
if err := c.Ping(ctx); err != nil {
|
||||||
|
slog.Warn("failed to get current memory", "error", err)
|
||||||
|
}
|
||||||
|
return c.memory.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemorySize implements llm.LlamaServer.
|
||||||
|
func (c *Client) MemorySize() (total, vram uint64) {
|
||||||
|
mem := c.currentMemory()
|
||||||
|
return mem, mem
|
||||||
}
|
}
|
||||||
|
|
||||||
// VRAMByGPU implements llm.LlamaServer.
|
// VRAMByGPU implements llm.LlamaServer.
|
||||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||||
return c.vramSize
|
return c.currentMemory()
|
||||||
}
|
|
||||||
|
|
||||||
// VRAMSize implements llm.LlamaServer.
|
|
||||||
func (c *Client) VRAMSize() uint64 {
|
|
||||||
return c.vramSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitUntilRunning implements llm.LlamaServer.
|
// WaitUntilRunning implements llm.LlamaServer.
|
||||||
|
|||||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
|||||||
return int(peak)
|
return int(peak)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetPeakMemory() {
|
||||||
|
C.mlx_reset_peak_memory()
|
||||||
|
}
|
||||||
|
|
||||||
type Memory struct{}
|
type Memory struct{}
|
||||||
|
|
||||||
func (Memory) LogValue() slog.Value {
|
func (Memory) LogValue() slog.Value {
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type Model interface {
|
|||||||
Unembed(x *mlx.Array) *mlx.Array
|
Unembed(x *mlx.Array) *mlx.Array
|
||||||
NumLayers() int
|
NumLayers() int
|
||||||
Tokenizer() *tokenizer.Tokenizer
|
Tokenizer() *tokenizer.Tokenizer
|
||||||
|
MaxContextLength() int
|
||||||
|
|
||||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
@@ -44,16 +47,35 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
} else {
|
} else {
|
||||||
mlx.DisableCompile()
|
mlx.DisableCompile()
|
||||||
}
|
}
|
||||||
|
mlx.ResetPeakMemory()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
return errors.New("empty prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputs) >= r.contextLength {
|
||||||
|
return api.StatusError{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap generation to stay within the model's context length
|
||||||
|
maxGenerate := r.contextLength - len(inputs)
|
||||||
|
if request.Options.MaxTokens <= 0 {
|
||||||
|
request.Options.MaxTokens = maxGenerate
|
||||||
|
} else {
|
||||||
|
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
session := r.cache.begin(r.Model, inputs)
|
session := r.cache.begin(r.Model, inputs)
|
||||||
defer session.close()
|
defer session.close()
|
||||||
|
|
||||||
caches := session.caches
|
caches := session.caches
|
||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
|
||||||
for total-processed > 1 {
|
for total-processed > 1 {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -93,8 +115,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
now := time.Now()
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.MaxTokens {
|
||||||
if err := request.Ctx.Err(); err != nil {
|
if err := request.Ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -103,9 +124,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
nextSample, nextLogprobs = step(sample)
|
nextSample, nextLogprobs = step(sample)
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample)
|
||||||
final.PromptTokensDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,18 +133,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
session.outputs = append(session.outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
final.Token = int(output)
|
|
||||||
final.DoneReason = 0
|
final.DoneReason = 0
|
||||||
final.CompletionTokens = i
|
final.EvalCount = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
case request.Responses <- Response{
|
case request.Responses <- CompletionResponse{
|
||||||
Text: r.Decode(output, &b),
|
Content: r.Decode(output, &b),
|
||||||
Token: int(output),
|
|
||||||
}:
|
}:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,7 +155,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final.CompletionTokensDuration = time.Since(now)
|
final.EvalDuration = time.Since(now)
|
||||||
|
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||||
select {
|
select {
|
||||||
case <-request.Ctx.Done():
|
case <-request.Ctx.Done():
|
||||||
return request.Ctx.Err()
|
return request.Ctx.Err()
|
||||||
|
|||||||
@@ -4,14 +4,15 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -21,7 +22,7 @@ import (
|
|||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
TextCompletionsRequest
|
||||||
Responses chan Response
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
|||||||
} `json:"options"`
|
} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Text string `json:"content,omitempty"`
|
|
||||||
Token int `json:"token,omitempty"`
|
|
||||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
|
||||||
Done bool `json:"done,omitempty"`
|
|
||||||
DoneReason int `json:"done_reason,omitempty"`
|
|
||||||
|
|
||||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
|
||||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
|
||||||
CompletionTokens int `json:"eval_count,omitempty"`
|
|
||||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
|
||||||
TotalTokens int `json:"total_tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
Requests chan Request
|
Requests chan Request
|
||||||
cache kvCache
|
cache kvCache
|
||||||
|
contextLength int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Runner) Load(modelName string) error {
|
func (r *Runner) Load(modelName string) error {
|
||||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
|
|
||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
|
r.contextLength = m.MaxContextLength()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
|||||||
case request := <-r.Requests:
|
case request := <-r.Requests:
|
||||||
if err := request.Pipeline(request); err != nil {
|
if err := request.Pipeline(request); err != nil {
|
||||||
slog.Info("Request terminated", "error", err)
|
slog.Info("Request terminated", "error", err)
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
statusErr = api.StatusError{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
ErrorMessage: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||||
|
case <-request.Ctx.Done():
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
close(request.Responses)
|
close(request.Responses)
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ func Execute(args []string) error {
|
|||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||||
"status": 0,
|
Status: 0,
|
||||||
"progress": 100,
|
Progress: 100,
|
||||||
|
ContextLength: runner.contextLength,
|
||||||
|
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
slog.Error("Failed to encode response", "error", err)
|
slog.Error("Failed to encode response", "error", err)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -78,7 +80,7 @@ func Execute(args []string) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan Response)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
@@ -87,9 +89,6 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||||
if request.Options.MaxTokens < 1 {
|
|
||||||
request.Options.MaxTokens = 16 << 10
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Pipeline = runner.TextGenerationPipeline
|
request.Pipeline = runner.TextGenerationPipeline
|
||||||
request.Sampler = sample.New(
|
request.Sampler = sample.New(
|
||||||
|
|||||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||||
|
|
||||||
// MaxContextLength returns the maximum context length
|
// MaxContextLength returns the maximum context length
|
||||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||||
|
|
||||||
// VocabSize returns the vocabulary size
|
// VocabSize returns the vocabulary size
|
||||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||||
|
|||||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
|||||||
return len(m.Layers)
|
return len(m.Layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) MaxContextLength() int {
|
||||||
|
return int(m.MaxPositionEmbeddings)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
return m.tok
|
return m.tok
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user