mirror of
https://github.com/ollama/ollama.git
synced 2026-03-01 13:36:41 -05:00
Compare commits
15 Commits
v0.17.1-rc
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
330b19b73f | ||
|
|
8da09b1e7e | ||
|
|
a60b9adcce | ||
|
|
a16f96658b | ||
|
|
18ab09b431 | ||
|
|
638faeac54 | ||
|
|
dd5eb6337d | ||
|
|
79917cf80b | ||
|
|
cc90a035a0 | ||
|
|
d98dda4676 | ||
|
|
d69ddc1edc | ||
|
|
9bf41969f0 | ||
|
|
0f23b7bff5 | ||
|
|
4e57d2094e | ||
|
|
7f9efd53df |
14
api/types.go
14
api/types.go
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/internal/orderedmap"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -569,6 +570,7 @@ type DebugInfo struct {
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
@@ -934,6 +936,10 @@ func (m *Metrics) Summary() {
|
||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||
}
|
||||
|
||||
if m.PeakMemory > 0 {
|
||||
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
|
||||
}
|
||||
|
||||
if m.LoadDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
|
||||
}
|
||||
@@ -957,6 +963,14 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
func formatPeakMemory(b uint64) string {
|
||||
if b >= format.GibiByte {
|
||||
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
|
||||
}
|
||||
|
||||
return format.HumanBytes2(b)
|
||||
}
|
||||
|
||||
func (opts *Options) FromMap(m map[string]any) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
var (
|
||||
wv = &Webview{}
|
||||
uiServerPort int
|
||||
appStore *store.Store
|
||||
)
|
||||
|
||||
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
|
||||
@@ -208,6 +209,7 @@ func main() {
|
||||
uiServerPort = port
|
||||
|
||||
st := &store.Store{}
|
||||
appStore = st
|
||||
|
||||
// Enable CORS in development mode
|
||||
if devMode {
|
||||
@@ -294,8 +296,15 @@ func main() {
|
||||
|
||||
// Check for pending updates on startup (show tray notification if update is ready)
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
|
||||
// before that would dereference a nil tray callback.
|
||||
// TODO: refactor so the update check runs after platform init on all platforms.
|
||||
if runtime.GOOS == "windows" {
|
||||
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
|
||||
} else {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
}
|
||||
|
||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||
@@ -360,8 +369,7 @@ func startHiddenTasks() {
|
||||
slog.Info("deferring pending update for fast startup")
|
||||
} else {
|
||||
// Check if auto-update is enabled before automatically upgrading
|
||||
st := &store.Store{}
|
||||
settings, err := st.Settings()
|
||||
settings, err := appStore.Settings()
|
||||
if err != nil {
|
||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||
} else if !settings.AutoUpdateEnabled {
|
||||
|
||||
@@ -154,6 +154,10 @@ func handleURLSchemeRequest(urlScheme string) {
|
||||
}
|
||||
|
||||
func UpdateAvailable(ver string) error {
|
||||
if app.t == nil {
|
||||
slog.Debug("tray not yet initialized, skipping update notification")
|
||||
return nil
|
||||
}
|
||||
return app.t.UpdateAvailable(ver)
|
||||
}
|
||||
|
||||
@@ -165,6 +169,14 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
||||
log.Fatalf("Failed to start: %s", err)
|
||||
}
|
||||
|
||||
// Check for pending updates now that the tray is initialized.
|
||||
// The platform-independent check in app.go fires before osRun,
|
||||
// when app.t is still nil, so we must re-check here.
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
signals := make(chan os.Signal, 1)
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
|
||||
@@ -289,6 +289,7 @@ func (u *Updater) TriggerImmediateCheck() {
|
||||
|
||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||
u.checkNow = make(chan struct{}, 1)
|
||||
u.checkNow <- struct{}{} // Trigger first check after initial delay
|
||||
go func() {
|
||||
// Don't blast an update message immediately after startup
|
||||
time.Sleep(UpdateCheckInitialDelay)
|
||||
@@ -333,7 +334,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
|
||||
continue
|
||||
}
|
||||
|
||||
// Download successful - show tray notification (regardless of toggle state)
|
||||
// Download successful - show tray notification
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn("failed to register update available with tray", "error", err)
|
||||
|
||||
@@ -351,10 +351,13 @@ func TestTriggerImmediateCheck(t *testing.T) {
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
|
||||
// Wait for goroutine to start and pass initial delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// Wait for the initial check that fires after the initial delay
|
||||
select {
|
||||
case <-checkDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("initial check did not happen")
|
||||
}
|
||||
|
||||
// With 1 hour interval, no check should have happened yet
|
||||
initialCount := checkCount.Load()
|
||||
|
||||
// Trigger immediate check
|
||||
|
||||
@@ -74,8 +74,7 @@ type LlamaServer interface {
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
VRAMSize() uint64 // Total VRAM across all GPUs
|
||||
TotalSize() uint64
|
||||
MemorySize() (total, vram uint64)
|
||||
VRAMByGPU(id ml.DeviceID) uint64
|
||||
Pid() int
|
||||
GetPort() int
|
||||
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
totalSize, _ := s.MemorySize()
|
||||
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
|
||||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
|
||||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
|
||||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
|
||||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
|
||||
@@ -1518,6 +1518,7 @@ type CompletionResponse struct {
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
PeakMemory uint64 `json:"peak_memory,omitempty"`
|
||||
|
||||
// Logprobs contains log probability information if requested
|
||||
Logprobs []Logprob `json:"logprobs,omitempty"`
|
||||
@@ -1848,17 +1849,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMSize() uint64 {
|
||||
func (s *llmServer) MemorySize() (total, vram uint64) {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
var mem uint64
|
||||
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
vram += g.Size()
|
||||
}
|
||||
|
||||
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
|
||||
|
||||
// Some elements are always on CPU. However, if we have allocated all layers
|
||||
// on the GPU then include the CPU components as well, to represent complete offloading.
|
||||
noCPULayers := true
|
||||
@@ -1869,25 +1870,11 @@ func (s *llmServer) VRAMSize() uint64 {
|
||||
}
|
||||
}
|
||||
if noCPULayers {
|
||||
mem += s.mem.InputWeights
|
||||
mem += s.mem.CPU.Graph
|
||||
vram += s.mem.InputWeights
|
||||
vram += s.mem.CPU.Graph
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
func (s *llmServer) TotalSize() uint64 {
|
||||
if s.mem == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
mem := s.mem.InputWeights
|
||||
mem += s.mem.CPU.Size()
|
||||
for _, g := range s.mem.GPUs {
|
||||
mem += g.Size()
|
||||
}
|
||||
|
||||
return mem
|
||||
return total, vram
|
||||
}
|
||||
|
||||
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
|
||||
@@ -195,6 +195,7 @@ type Tensor interface {
|
||||
Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
Rows(ctx Context, t2 Tensor) Tensor
|
||||
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
|
||||
SetInplace(ctx Context, src Tensor, nb1, nb2, nb3, offset int) Tensor
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
Duplicate(ctx Context) Tensor
|
||||
|
||||
|
||||
@@ -1345,6 +1345,21 @@ func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tenso
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_set_inplace(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
src.(*Tensor).t,
|
||||
C.size_t(nb1),
|
||||
C.size_t(nb2),
|
||||
C.size_t(nb3),
|
||||
C.size_t(offset),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
||||
@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
|
||||
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
|
||||
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
|
||||
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
|
||||
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
|
||||
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
|
||||
SSMOut *nn.Linear `gguf:"ssm_out"`
|
||||
|
||||
@@ -122,7 +122,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
|
||||
a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
|
||||
|
||||
// Keep beta layout consistent with qwen35 and llama.cpp:
|
||||
// Keep beta layout consistent with qwen35.
|
||||
// [1, numVHeads, nSeqTokens, nSeqs]
|
||||
beta = b.Contiguous(ctx, 1, numVHeads, nSeqTokens, nSeqs)
|
||||
alpha = a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
|
||||
@@ -135,6 +135,18 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
|
||||
default:
|
||||
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
|
||||
}
|
||||
|
||||
// Compute gate: softplus(alpha + dt_bias) * -A
|
||||
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
|
||||
@@ -333,7 +345,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
|
||||
v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
|
||||
// Match llama.cpp delta-net-base layout:
|
||||
// gate/beta: [1, numVHeads, nTokens, nSeqs] -> [1, nTokens, numVHeads, nSeqs]
|
||||
gate = gate.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||
beta = beta.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, 1, nTokens, numVHeads, nSeqs)
|
||||
@@ -437,60 +448,64 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
|
||||
keyGDiff := k.Mul(ctx, gDiffExpReshaped)
|
||||
keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Process chunks and update state
|
||||
var coreAttnOut ml.Tensor
|
||||
newState := state
|
||||
// Process chunks and update state.
|
||||
// Keep a transposed view of v and recurrent state across chunks so the
|
||||
// chunk loop does not need extra transpose+contiguous nodes.
|
||||
vT := v.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, chunkSize, headVDim, nChunks, numVHeads*nSeqs)
|
||||
stateT := state.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
|
||||
for chunk := range nChunks {
|
||||
qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
vTChunk := vT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
|
||||
|
||||
// state^T - permute is needed but Contiguous creates a copy
|
||||
stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
|
||||
// v'_t = k_cumdecay @ state_t
|
||||
vTPrime := kCumdecayChunk.Mulmat(ctx, stateT)
|
||||
|
||||
// v_prime = k_cumdecay @ state
|
||||
vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
|
||||
|
||||
// v_new = v - v_prime
|
||||
vNew := vChunk.Sub(ctx, vPrime)
|
||||
vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
// v_t_new = v_t - v'_t
|
||||
vTNewChunk := vTChunk.Sub(ctx, vTPrime)
|
||||
|
||||
// attn_inter = (q * g_exp) @ state
|
||||
qGExp := qChunk.Mul(ctx, gExpChunk)
|
||||
attnInter := stateT.Mulmat(ctx, qGExp)
|
||||
|
||||
// core_attn_out = attn_inter + attn @ v_new
|
||||
vAttn := vNewT.Mulmat(ctx, attnChunk)
|
||||
vAttn := vTNewChunk.Mulmat(ctx, attnChunk)
|
||||
coreAttnOutChunk := attnInter.Add(ctx, vAttn)
|
||||
|
||||
if coreAttnOut == nil {
|
||||
coreAttnOut = coreAttnOutChunk
|
||||
} else {
|
||||
coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
|
||||
}
|
||||
v = v.SetInplace(
|
||||
ctx,
|
||||
coreAttnOutChunk,
|
||||
v.Stride(1),
|
||||
v.Stride(2),
|
||||
v.Stride(3),
|
||||
chunk*v.Stride(2),
|
||||
)
|
||||
|
||||
// Update state for next chunk
|
||||
gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
|
||||
kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
|
||||
// kgdmulvnew = key_gdiff_t @ v_new_t
|
||||
kgdMulVNew := kGDiffChunkT.Mulmat(ctx, vTNewChunk)
|
||||
|
||||
// state = state * g_last + kgdmulvnew
|
||||
gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
|
||||
newState = newState.Mul(ctx, gExpLastReshaped)
|
||||
newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
|
||||
// stateT = stateT * g_last + kgdmulvnew
|
||||
stateT = stateT.Mul(ctx, gExpLastChunk)
|
||||
stateT = stateT.Add(ctx, kgdMulVNew)
|
||||
}
|
||||
|
||||
// Final reshape
|
||||
coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
coreAttnOut := v.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
|
||||
|
||||
// Slice to remove padding
|
||||
if pad > 0 {
|
||||
coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
|
||||
}
|
||||
|
||||
// Convert stateT back to cache layout [S_v, S_v, H_v, nSeqs]
|
||||
newState := stateT.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, numVHeads, nSeqs)
|
||||
|
||||
// Update delta state in cache
|
||||
cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
|
||||
|
||||
|
||||
@@ -437,6 +437,46 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.Options == nil {
|
||||
return fmt.Errorf("qwen3next: missing model options")
|
||||
}
|
||||
if len(m.Layers) != len(m.Options.isRecurrent) {
|
||||
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
if !m.Options.isRecurrent[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
gdn, ok := layer.Operator.(*GatedDeltaNet)
|
||||
if !ok || gdn == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
|
||||
}
|
||||
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
|
||||
}
|
||||
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
|
||||
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
|
||||
}
|
||||
if gdn.SSMDT == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
|
||||
}
|
||||
if gdn.SSMA == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
|
||||
}
|
||||
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
|
||||
}
|
||||
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
|
||||
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
if len(m.mropeSections) > 0 {
|
||||
@@ -450,6 +490,64 @@ var (
|
||||
_ model.MultimodalProcessor = (*Model)(nil)
|
||||
)
|
||||
|
||||
func defaultVHeadReordered(arch string) bool {
|
||||
return arch == "qwen35" || arch == "qwen35moe"
|
||||
}
|
||||
|
||||
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
|
||||
isRecurrent := make([]bool, numLayers)
|
||||
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
if i >= len(headCountKV) {
|
||||
continue
|
||||
}
|
||||
|
||||
if headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if hasZero && hasFull {
|
||||
return isRecurrent, nil
|
||||
}
|
||||
if !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Compatibility path: older imports store a scalar KV head count and omit
|
||||
// per-layer recurrent flags. Derive the hybrid layout from the interval.
|
||||
interval := int(fullAttentionInterval)
|
||||
if interval == 0 {
|
||||
interval = min(4, numLayers)
|
||||
}
|
||||
if interval <= 0 {
|
||||
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
|
||||
}
|
||||
if interval > numLayers {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
|
||||
}
|
||||
|
||||
hasZero = false
|
||||
hasFull = false
|
||||
for i := range numLayers {
|
||||
isRecurrent[i] = (i+1)%interval != 0
|
||||
if isRecurrent[i] {
|
||||
hasZero = true
|
||||
} else {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
|
||||
}
|
||||
|
||||
return isRecurrent, nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
layers := make([]Layer, numLayers)
|
||||
@@ -460,26 +558,14 @@ func New(c fs.Config) (model.Model, error) {
|
||||
HeadCountKV() []uint64
|
||||
}
|
||||
|
||||
var isRecurrent []bool
|
||||
var headCountKV []uint64
|
||||
if hc, ok := c.(headCounts); ok {
|
||||
headCountKV = hc.HeadCountKV()
|
||||
}
|
||||
|
||||
isRecurrent = make([]bool, numLayers)
|
||||
hasZero := false
|
||||
hasFull := false
|
||||
for i := range numLayers {
|
||||
// If KV head count is 0, it's a recurrent layer
|
||||
if i < len(headCountKV) && headCountKV[i] == 0 {
|
||||
isRecurrent[i] = true
|
||||
hasZero = true
|
||||
} else if i < len(headCountKV) && headCountKV[i] > 0 {
|
||||
hasFull = true
|
||||
}
|
||||
}
|
||||
if !hasZero || !hasFull {
|
||||
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
|
||||
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine if MoE
|
||||
@@ -543,7 +629,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
ssmNGroup: int(c.Uint("ssm.group_count")),
|
||||
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
|
||||
convKernelSize: int(c.Uint("ssm.conv_kernel")),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
|
||||
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
|
||||
isRecurrent: isRecurrent,
|
||||
mropeSections: slices.Collect(func(yield func(int) bool) {
|
||||
for _, section := range mropeSections {
|
||||
@@ -555,7 +641,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
|
||||
}
|
||||
if opts.numKVHeads == 0 {
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
|
||||
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
|
||||
}
|
||||
|
||||
// Calculate cache dimensions
|
||||
|
||||
65
model/models/qwen3next/model_new_test.go
Normal file
65
model/models/qwen3next/model_new_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, false, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, true, false, true, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
|
||||
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("inferRecurrentLayers() error = %v", err)
|
||||
}
|
||||
|
||||
want := []bool{true, true, false, true, true, false}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
|
||||
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
|
||||
if err == nil {
|
||||
t.Fatal("inferRecurrentLayers() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultVHeadReordered(t *testing.T) {
|
||||
if !defaultVHeadReordered("qwen35") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
|
||||
}
|
||||
if !defaultVHeadReordered("qwen35moe") {
|
||||
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
|
||||
}
|
||||
if defaultVHeadReordered("qwen3next") {
|
||||
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
|
||||
}
|
||||
}
|
||||
45
model/models/qwen3next/model_validate_test.go
Normal file
45
model/models/qwen3next/model_validate_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package qwen3next
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{
|
||||
Operator: &GatedDeltaNet{
|
||||
SSMQKV: &nn.Linear{},
|
||||
SSMQKVGate: &nn.Linear{},
|
||||
SSMBeta: &nn.Linear{},
|
||||
SSMAlpha: &nn.Linear{},
|
||||
},
|
||||
}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{true},
|
||||
},
|
||||
}
|
||||
|
||||
err := m.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Validate() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing ssm_dt") {
|
||||
t.Fatalf("unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []Layer{{Operator: &FullAttention{}}},
|
||||
Options: &Options{
|
||||
isRecurrent: []bool{false},
|
||||
},
|
||||
}
|
||||
|
||||
if err := m.Validate(); err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,10 @@ const (
|
||||
)
|
||||
|
||||
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
@@ -48,6 +49,7 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools
|
||||
}
|
||||
|
||||
@@ -89,6 +91,8 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
|
||||
@@ -11,6 +11,7 @@ type GLM47Parser struct {
|
||||
|
||||
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
// When thinking is enabled (nil or true), the prompt ends with <think>,
|
||||
// so model output starts directly with thinking content (no opening tag).
|
||||
if thinkValue == nil || thinkValue.Bool() {
|
||||
|
||||
@@ -97,3 +97,91 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
|
||||
t.Fatalf("expected %#v, got %#v", expected, toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexing(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `plan</think>
|
||||
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
|
||||
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
|
||||
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
|
||||
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := GLM47Parser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type Qwen3Parser struct {
|
||||
state qwen3ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
hasThinkingSupport bool
|
||||
defaultThinking bool
|
||||
maybeThinkingOpenAtBOL bool
|
||||
@@ -54,6 +55,7 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.buffer.Reset()
|
||||
p.callIndex = 0
|
||||
|
||||
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||
if thinkValue == nil {
|
||||
@@ -106,6 +108,8 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
|
||||
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
calls = append(calls, toolCall)
|
||||
case qwen3EventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
@@ -204,6 +208,24 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
@@ -215,7 +237,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
@@ -146,6 +146,68 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "Let me think" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "Let me think" || len(calls) != 0 {
|
||||
t.Fatalf(
|
||||
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
|
||||
"",
|
||||
"Let me think",
|
||||
0,
|
||||
content,
|
||||
thinking,
|
||||
len(calls),
|
||||
)
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
@@ -168,3 +230,89 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexing(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
|
||||
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
|
||||
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
|
||||
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,9 +29,10 @@ const (
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
callIndex int
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
@@ -44,6 +45,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
|
||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.callIndex = 0
|
||||
return tools // Qwen doesn't modify tools
|
||||
}
|
||||
|
||||
@@ -62,6 +64,8 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCall.Function.Index = p.callIndex
|
||||
p.callIndex++
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
|
||||
@@ -1035,6 +1035,92 @@ func TestQwenToolCallValueParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
|
||||
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
|
||||
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
|
||||
_, _, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(calls) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(calls[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var all []api.ToolCall
|
||||
|
||||
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
|
||||
if err != nil {
|
||||
t.Fatalf("step 1 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("step 2 parse failed: %v", err)
|
||||
}
|
||||
all = append(all, calls...)
|
||||
|
||||
want := []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
|
||||
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
|
||||
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
|
||||
}
|
||||
if len(all) != len(want) {
|
||||
t.Fatalf("expected %d calls, got %d", len(want), len(all))
|
||||
}
|
||||
for i := range want {
|
||||
if !toolCallEqual(all[i], want[i]) {
|
||||
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
|
||||
parser := Qwen3CoderParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("first parse failed: %v", err)
|
||||
}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("second parse failed: %v", err)
|
||||
}
|
||||
|
||||
want := api.ToolCall{
|
||||
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if !toolCallEqual(calls[0], want) {
|
||||
t.Fatalf("got %#v, want %#v", calls[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenXMLTransform(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
|
||||
@@ -180,7 +180,22 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
return events, false
|
||||
}
|
||||
case CollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
acc := p.buffer.String()
|
||||
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, toolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: before})
|
||||
}
|
||||
p.state = CollectingToolContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||
@@ -191,13 +206,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
p.state = CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
@@ -205,11 +220,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
|
||||
@@ -98,8 +98,12 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -109,8 +113,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -121,8 +124,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
|
||||
@@ -71,6 +71,10 @@ type Model struct {
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
func (m *Model) IsMLX() bool {
|
||||
return m.Config.ModelFormat == "safetensors"
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
lastMsgIdx := len(msgs) - 1
|
||||
currMsgIdx := 0
|
||||
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
if truncate {
|
||||
// Start with all messages and remove from the front until it fits in context
|
||||
for i := 0; i <= lastMsgIdx; i++ {
|
||||
// Collect system messages from the portion we're about to skip
|
||||
system = make([]api.Message, 0)
|
||||
for j := range i {
|
||||
if msgs[j].Role == "system" {
|
||||
system = append(system, msgs[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if !truncate || ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
ctxLen := len(s)
|
||||
if m.ProjectorPaths != nil {
|
||||
for _, msg := range msgs[i:] {
|
||||
ctxLen += imageNumTokens * len(msg.Images)
|
||||
}
|
||||
}
|
||||
|
||||
if ctxLen <= opts.NumCtx {
|
||||
currMsgIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
// Must always include at least the last message
|
||||
if i == lastMsgIdx {
|
||||
currMsgIdx = lastMsgIdx
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,33 +21,76 @@ type quantizer struct {
|
||||
progressFn func(n uint64)
|
||||
}
|
||||
|
||||
const quantizationChunkElements uint64 = 4 * 1024 * 1024
|
||||
|
||||
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
|
||||
quantize := q.from.Kind != q.to.Kind
|
||||
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
|
||||
if !quantize {
|
||||
n, err := io.Copy(w, sr)
|
||||
q.progressFn(q.from.Size())
|
||||
if q.progressFn != nil {
|
||||
q.progressFn(q.from.Size())
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
data, err := io.ReadAll(sr)
|
||||
if err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
|
||||
|
||||
if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 {
|
||||
return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape)
|
||||
}
|
||||
if uint64(len(data)) < q.from.Size() {
|
||||
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
|
||||
|
||||
fromType := fsggml.TensorType(q.from.Kind)
|
||||
toType := fsggml.TensorType(q.to.Kind)
|
||||
nPerRow := q.from.Shape[0]
|
||||
totalElements := q.from.Elements()
|
||||
if totalElements%nPerRow != 0 {
|
||||
return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape)
|
||||
}
|
||||
var f32s []float32
|
||||
newType := fsggml.TensorType(q.to.Kind)
|
||||
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
|
||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
|
||||
} else {
|
||||
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
|
||||
|
||||
inRowSize := fromType.RowSize(nPerRow)
|
||||
if inRowSize == 0 {
|
||||
return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType)
|
||||
}
|
||||
data = ggml.Quantize(newType, f32s, q.from.Shape)
|
||||
n, err := w.Write(data)
|
||||
q.progressFn(q.from.Size())
|
||||
return int64(n), err
|
||||
|
||||
totalRows := totalElements / nPerRow
|
||||
rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1))
|
||||
chunkBuf := make([]byte, inRowSize*rowsPerChunk)
|
||||
var written int64
|
||||
|
||||
for row := uint64(0); row < totalRows; {
|
||||
chunkRows := min(rowsPerChunk, totalRows-row)
|
||||
chunkBytes := inRowSize * chunkRows
|
||||
data := chunkBuf[:chunkBytes]
|
||||
|
||||
if _, err := io.ReadFull(sr, data); err != nil {
|
||||
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
|
||||
return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
chunkElements := chunkRows * nPerRow
|
||||
if fromType == fsggml.TensorTypeF32 {
|
||||
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements)
|
||||
} else {
|
||||
f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements)
|
||||
}
|
||||
|
||||
quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows})
|
||||
n, err := w.Write(quantized)
|
||||
written += int64(n)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if n != len(quantized) {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
|
||||
if q.progressFn != nil {
|
||||
q.progressFn(chunkBytes)
|
||||
}
|
||||
row += chunkRows
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
type quantizeState struct {
|
||||
|
||||
@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -557,6 +558,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
PromptEvalDuration: cr.PromptEvalDuration,
|
||||
EvalCount: cr.EvalCount,
|
||||
EvalDuration: cr.EvalDuration,
|
||||
PeakMemory: cr.PeakMemory,
|
||||
},
|
||||
Logprobs: toAPILogprobs(cr.Logprobs),
|
||||
}
|
||||
@@ -1951,6 +1953,9 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
}
|
||||
if v.llama != nil {
|
||||
mr.ContextLength = v.llama.ContextLength()
|
||||
total, vram := v.llama.MemorySize()
|
||||
mr.Size = int64(total)
|
||||
mr.SizeVRAM = int64(vram)
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
@@ -2213,6 +2218,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
if m.IsMLX() {
|
||||
truncate = false
|
||||
}
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
@@ -2309,6 +2317,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
PeakMemory: r.PeakMemory,
|
||||
},
|
||||
Logprobs: toAPILogprobs(r.Logprobs),
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if pending.model.IsMLX() {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
@@ -536,6 +536,7 @@ iGPUScan:
|
||||
}
|
||||
}
|
||||
|
||||
totalSize, vramSize := llama.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -545,8 +546,8 @@ iGPUScan:
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpuIDs,
|
||||
discreteGPUs: discreteGPUs,
|
||||
vramSize: llama.VRAMSize(),
|
||||
totalSize: llama.TotalSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
totalSize, vramSize := server.MemorySize()
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
loading: false,
|
||||
isImagegen: isImagegen,
|
||||
sessionDuration: sessionDuration,
|
||||
totalSize: server.TotalSize(),
|
||||
vramSize: server.VRAMSize(),
|
||||
totalSize: totalSize,
|
||||
vramSize: vramSize,
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
|
||||
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
|
||||
runner.llama.Ping(ctx) != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
|
||||
s.closeCalled = true
|
||||
return s.closeResp
|
||||
}
|
||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
||||
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
|
||||
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
func (s *mockLlm) GetPort() int { return -1 }
|
||||
|
||||
@@ -374,14 +374,9 @@ func (s *Server) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VRAMSize returns the estimated VRAM usage.
|
||||
func (s *Server) VRAMSize() uint64 {
|
||||
return s.vramSize
|
||||
}
|
||||
|
||||
// TotalSize returns the total memory usage.
|
||||
func (s *Server) TotalSize() uint64 {
|
||||
return s.vramSize
|
||||
// MemorySize returns the total and VRAM memory usage.
|
||||
func (s *Server) MemorySize() (total, vram uint64) {
|
||||
return s.vramSize, s.vramSize
|
||||
}
|
||||
|
||||
// VRAMByGPU returns VRAM usage for a specific GPU.
|
||||
|
||||
@@ -9,59 +9,105 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// CacheEntry stores a single sequence
|
||||
type CacheEntry struct {
|
||||
Tokens []int32
|
||||
Caches []cache.Cache
|
||||
type kvCache struct {
|
||||
// For now we only support a single entry, so this is just one sequence
|
||||
tokens []int32
|
||||
caches []cache.Cache
|
||||
}
|
||||
|
||||
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
||||
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||
if r.cache == nil {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
// cacheSession manages caches for a single pipeline run.
|
||||
// Callers should append generated tokens to outputs and
|
||||
// defer close to save the cache state.
|
||||
type cacheSession struct {
|
||||
cache *kvCache
|
||||
inputs []int32
|
||||
outputs []int32
|
||||
|
||||
caches []cache.Cache
|
||||
remaining []int32
|
||||
}
|
||||
|
||||
// begin prepares caches for a new request. It finds the nearest
|
||||
// matching cache or creates new caches if none match.
|
||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
if len(c.caches) == 0 {
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
} else {
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find longest common prefix
|
||||
remaining := c.findRemaining(inputs)
|
||||
|
||||
return &cacheSession{
|
||||
cache: c,
|
||||
inputs: inputs,
|
||||
caches: c.caches,
|
||||
remaining: remaining,
|
||||
}
|
||||
}
|
||||
|
||||
// close saves the token state if the forward pass ran.
|
||||
func (s *cacheSession) close() {
|
||||
if offset := s.caches[0].Offset(); offset > 0 {
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, c := range s.caches {
|
||||
k, v := c.State()
|
||||
arrays = append(arrays, k, v)
|
||||
}
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
|
||||
}
|
||||
}
|
||||
|
||||
// findRemaining finds the longest common prefix between tokens and the cached
|
||||
// sequence, trims stale cache entries, and returns the remaining tokens.
|
||||
func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
||||
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
|
||||
switch {
|
||||
case prefix == 0:
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Free()
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if prefix == len(tokens) && prefix > 0 {
|
||||
prefix--
|
||||
}
|
||||
|
||||
if prefix < len(c.tokens) {
|
||||
trim := len(c.tokens) - prefix
|
||||
for _, kv := range c.caches {
|
||||
kv.Trim(trim)
|
||||
}
|
||||
r.cache = nil
|
||||
c.tokens = c.tokens[:prefix]
|
||||
}
|
||||
|
||||
if prefix == 0 {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
case prefix < len(r.cache.Tokens):
|
||||
trim := len(r.cache.Tokens) - prefix
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Trim(trim)
|
||||
}
|
||||
r.cache.Tokens = r.cache.Tokens[:prefix]
|
||||
} else {
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
}
|
||||
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
return r.cache.Caches, tokens[prefix:]
|
||||
return tokens[prefix:]
|
||||
}
|
||||
|
||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
r.cache = &CacheEntry{
|
||||
Tokens: tokens,
|
||||
Caches: caches,
|
||||
func (c *kvCache) log() {
|
||||
if len(c.caches) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheEntry) LogCache() {
|
||||
var totalBytes int
|
||||
for _, kv := range c.Caches {
|
||||
for _, kv := range c.caches {
|
||||
k, v := kv.State()
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -19,25 +18,27 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
port int
|
||||
modelName string
|
||||
vramSize uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
port int
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan error
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
|
||||
@@ -98,18 +99,9 @@ func NewClient(modelName string) (*Client, error) {
|
||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
||||
}
|
||||
|
||||
// Estimate VRAM based on tensor size from manifest
|
||||
var vramSize uint64
|
||||
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
|
||||
vramSize = uint64(modelManifest.TotalTensorSize())
|
||||
} else {
|
||||
vramSize = 8 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
port: port,
|
||||
modelName: modelName,
|
||||
vramSize: vramSize,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
cmd: cmd,
|
||||
@@ -201,6 +193,20 @@ type completionOpts struct {
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
PeakMemory uint64
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
@@ -260,28 +266,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount int `json:"eval_count,omitempty"`
|
||||
EvalDuration int `json:"eval_duration,omitempty"`
|
||||
}
|
||||
var raw CompletionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: time.Duration(raw.EvalDuration),
|
||||
EvalDuration: raw.EvalDuration,
|
||||
PeakMemory: raw.PeakMemory,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
@@ -294,7 +297,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return math.MaxInt
|
||||
return int(c.contextLength.Load())
|
||||
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
@@ -347,9 +350,16 @@ func (c *Client) Pid() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status int
|
||||
Progress int
|
||||
ContextLength int
|
||||
Memory uint64
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -362,6 +372,15 @@ func (c *Client) Ping(ctx context.Context) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var status statusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.contextLength.Store(int64(status.ContextLength))
|
||||
c.memory.Store(status.Memory)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -388,19 +407,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
|
||||
func (c *Client) TotalSize() uint64 {
|
||||
return c.vramSize
|
||||
func (c *Client) currentMemory() uint64 {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := c.Ping(ctx); err != nil {
|
||||
slog.Warn("failed to get current memory", "error", err)
|
||||
}
|
||||
return c.memory.Load()
|
||||
}
|
||||
|
||||
// MemorySize implements llm.LlamaServer.
|
||||
func (c *Client) MemorySize() (total, vram uint64) {
|
||||
mem := c.currentMemory()
|
||||
return mem, mem
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return c.vramSize
|
||||
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
|
||||
func (c *Client) VRAMSize() uint64 {
|
||||
return c.vramSize
|
||||
return c.currentMemory()
|
||||
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
|
||||
|
||||
@@ -64,6 +64,10 @@ func PeakMemory() int {
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
func ResetPeakMemory() {
|
||||
C.mlx_reset_peak_memory()
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
|
||||
@@ -20,6 +20,7 @@ type Model interface {
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
@@ -19,6 +21,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
var (
|
||||
sample, logprobs *mlx.Array
|
||||
nextSample, nextLogprobs *mlx.Array
|
||||
)
|
||||
|
||||
defer func() {
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
mlx.LogArrays()
|
||||
r.cache.log()
|
||||
}
|
||||
}()
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
@@ -28,24 +47,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
if len(caches) == 0 {
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
} else {
|
||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
if len(inputs) >= r.contextLength {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.MaxTokens <= 0 {
|
||||
request.Options.MaxTokens = maxGenerate
|
||||
} else {
|
||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||
}
|
||||
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := min(2<<10, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
@@ -76,61 +111,58 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return sample, logprobs
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
output := int32(sample.Int())
|
||||
outputs = append(outputs, output)
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
final.EvalCount = i
|
||||
break
|
||||
}
|
||||
|
||||
request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
}:
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
nextSample, nextLogprobs = nil, nil
|
||||
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
mlx.Sweep()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
mlx.LogArrays()
|
||||
if r.cache != nil {
|
||||
r.cache.LogCache()
|
||||
}
|
||||
final.EvalDuration = time.Since(now)
|
||||
final.PeakMemory = uint64(mlx.PeakMemory())
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
|
||||
@@ -4,15 +4,15 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -22,11 +22,12 @@ import (
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
Responses chan Response
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
|
||||
Ctx context.Context
|
||||
|
||||
sample.Sampler
|
||||
caches []cache.Cache
|
||||
}
|
||||
|
||||
type TextCompletionsRequest struct {
|
||||
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Text string `json:"content,omitempty"`
|
||||
Token int `json:"token,omitempty"`
|
||||
Logprobs []float32 `json:"logprobs,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
DoneReason int `json:"done_reason,omitempty"`
|
||||
|
||||
PromptTokens int `json:"prompt_eval_count,omitempty"`
|
||||
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||
CompletionTokens int `json:"eval_count,omitempty"`
|
||||
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache *CacheEntry
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache kvCache
|
||||
contextLength int
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,7 +146,18 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
return nil
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
break
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
|
||||
@@ -5,6 +5,7 @@ package mlxrunner
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -49,9 +50,11 @@ func Execute(args []string) error {
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": 0,
|
||||
"progress": 100,
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -77,7 +80,7 @@ func Execute(args []string) error {
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan Response)}
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
@@ -86,9 +89,6 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
if request.Options.MaxTokens < 1 {
|
||||
request.Options.MaxTokens = 16 << 10
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
@@ -98,19 +98,36 @@ func Execute(args []string) error {
|
||||
request.Options.TopK,
|
||||
)
|
||||
|
||||
runner.Requests <- request
|
||||
var cancel context.CancelFunc
|
||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case runner.Requests <- request:
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/jsonl")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
for response := range request.Responses {
|
||||
if err := enc.Encode(response); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
case response, ok := <-request.Responses:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
if err := enc.Encode(response); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||
func (m *Model) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// MaxContextLength returns the maximum context length
|
||||
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
||||
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
||||
|
||||
@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
|
||||
return len(m.Layers)
|
||||
}
|
||||
|
||||
func (m *Model) MaxContextLength() int {
|
||||
return int(m.MaxPositionEmbeddings)
|
||||
}
|
||||
|
||||
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||
return m.tok
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user