mirror of https://github.com/ollama/ollama.git (synced 2026-02-06 05:34:21 -05:00)

Compare commits: pdevine/la ... jmorganca/

2 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | c330ea33ed |  |
|  | c61023f554 |  |
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
     }
     ChatTestHelper(ctx, t, req, blueSkyExpected)
 }
+
+// TestNumPredict verifies that when num_predict is set, the model generates
+// exactly that many tokens. It uses logprobs to count the actual tokens output.
+func TestNumPredict(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+    defer cancel()
+
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
+        t.Fatalf("failed to pull model: %v", err)
+    }
+
+    req := api.GenerateRequest{
+        Model:    "qwen3:0.6b",
+        Prompt:   "Write a long story.",
+        Stream:   &stream,
+        Logprobs: true,
+        Options: map[string]any{
+            "num_predict": 10,
+            "temperature": 0,
+            "seed":        123,
+        },
+    }
+
+    logprobCount := 0
+    var finalResponse api.GenerateResponse
+    err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
+        logprobCount += len(resp.Logprobs)
+        if resp.Done {
+            finalResponse = resp
+        }
+        return nil
+    })
+    if err != nil {
+        t.Fatalf("generate failed: %v", err)
+    }
+
+    if logprobCount != 10 {
+        t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
+            logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
+    }
+}
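Note: the test counts one logprob entry per generated token. A rough equivalent check (not part of this change, and assuming the same client, ctx and model setup as the test above) is to read EvalCount from the final streamed response, which should also reflect the num_predict cap:

    // Sketch only: the same assertion via the final response's metrics.
    req := api.GenerateRequest{
        Model:   "qwen3:0.6b",
        Prompt:  "Write a long story.",
        Options: map[string]any{"num_predict": 10},
    }
    var final api.GenerateResponse
    if err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
        if resp.Done {
            final = resp
        }
        return nil
    }); err != nil {
        t.Fatalf("generate failed: %v", err)
    }
    if final.EvalCount != 10 {
        t.Errorf("expected EvalCount=10, got %d", final.EvalCount)
    }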
@@ -39,12 +39,15 @@ func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
         if nSeqs > 0 {
             // 3D tensor: [hiddenDim, seqTokens, nSeqs]
             if batchSize != seqTokens || nSeqs != seqs {
-                return nil, ErrUnsupportedBatchLayout
+                // Fallback: treat as flat batch if layout doesn't match.
+                hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, batchSize*nSeqs)
+                batchSize = batchSize * nSeqs
+            } else {
+                hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
+                batchSize = seqTokens * seqs
             }
-            hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
-            batchSize = seqTokens * seqs
         } else if batchSize != seqTokens*seqs {
-            return nil, ErrUnsupportedBatchLayout
+            // Layout mismatch; proceed with flat batch.
         }
     }
 }
@@ -64,6 +64,8 @@ type HybridCache struct {
     curSlots      []int
     curSlotsInput ml.Tensor
     curSeqTokens  int
+    // token indices per sequence in batch order
+    curSeqTokenIdxs [][]int32
 
     // track if EnsureWritable has been called for this forward pass
     writableEnsured bool
@@ -168,19 +170,44 @@ func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bo
     }
 
     if len(c.curSeqs) == 0 {
+        c.curSeqTokenIdxs = c.curSeqTokenIdxs[:0]
         return nil
     }
 
+    if cap(c.curSeqTokenIdxs) < len(c.curSeqs) {
+        c.curSeqTokenIdxs = make([][]int32, len(c.curSeqs))
+    } else {
+        c.curSeqTokenIdxs = c.curSeqTokenIdxs[:len(c.curSeqs)]
+    }
+    for i := range c.curSeqTokenIdxs {
+        c.curSeqTokenIdxs[i] = c.curSeqTokenIdxs[i][:0]
+    }
+
+    seqIndex := make(map[int]int, len(c.curSeqs))
+    for i, s := range c.curSeqs {
+        seqIndex[s] = i
+    }
+    for i, s := range batch.Sequences {
+        c.curSeqTokenIdxs[seqIndex[s]] = append(c.curSeqTokenIdxs[seqIndex[s]], int32(i))
+    }
+
     nTokens := len(batch.Sequences)
     nSeqs := len(c.curSeqs)
     want := nTokens / nSeqs
+    uniform := true
     for _, s := range c.curSeqs {
         if seqCounts[s] != want {
-            return kvcache.ErrNotSupported
+            uniform = false
+            break
         }
     }
 
-    c.curSeqTokens = want
+    if uniform {
+        c.curSeqTokens = want
+    } else {
+        // Mixed batch: recurrent layers will process sequences independently.
+        c.curSeqTokens = 0
+    }
 
     // When reserving memory for estimation, use fake slot assignments
     if reserve {
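For clarity, here is the bookkeeping StartForward now performs, rewritten over plain slices (illustration only; the names are ad hoc and it assumes a non-empty curSeqs, as in the code after its early return): token positions are grouped per sequence in batch order, and the grid width stays non-zero only when every sequence contributes the same number of tokens.

    // Illustration of the grouping and uniformity check (not the cache implementation).
    func groupBySequence(sequences []int, curSeqs []int) (idxs [][]int32, perSeq int) {
        seqIndex := make(map[int]int, len(curSeqs))
        for i, s := range curSeqs {
            seqIndex[s] = i
        }
        idxs = make([][]int32, len(curSeqs))
        for i, s := range sequences {
            idxs[seqIndex[s]] = append(idxs[seqIndex[s]], int32(i))
        }
        perSeq = len(sequences) / len(curSeqs) // candidate uniform width
        for _, ix := range idxs {
            if len(ix) != perSeq {
                return idxs, 0 // mixed batch: handled per sequence
            }
        }
        return idxs, perSeq
    }

    // groupBySequence([]int{0, 0, 1, 1, 1}, []int{0, 1})
    // => [[0 1] [2 3 4]], 0   (counts differ, so the grid path is skipped)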
@@ -585,7 +612,101 @@ func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Te
     c.captureDeltaCheckpoint(ctx, layer, srcF32)
 }
 
-// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
+// convStateForSlot returns the conv state for a single slot as [convDim, convChannels, 1].
+func (c *HybridCache) convStateForSlot(ctx ml.Context, layer int, slot int) (ml.Tensor, error) {
+    c.ensureWritableOnce(ctx)
+    if c.writableError != nil {
+        return nil, c.writableError
+    }
+    buf := c.convBuffer(ctx, layer)
+    slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+    cur := buf.Rows(ctx, slotIdx)
+    return cur.Reshape(ctx, c.convDim, c.convChannels, 1), nil
+}
+
+// updateConvStateForSlot writes a new conv state for a single slot.
+func (c *HybridCache) updateConvStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
+    buf := c.convBuffer(ctx, layer)
+    src := newState.Reshape(ctx, c.convDim*c.convChannels, 1)
+    srcF32 := src.Cast(ctx, ml.DTypeF32)
+    slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+    ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
+    c.captureConvCheckpointForSeq(ctx, layer, seqIndex, srcF32)
+}
+
+// deltaStateForSlot returns the delta state for a single slot as [headVDim, headVDim*numVHeads, 1].
+func (c *HybridCache) deltaStateForSlot(ctx ml.Context, layer int, slot int, headVDim, numVHeads int) (ml.Tensor, error) {
+    c.ensureWritableOnce(ctx)
+    if c.writableError != nil {
+        return nil, c.writableError
+    }
+    buf := c.deltaBuffer(ctx, layer)
+    slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+    cur := buf.Rows(ctx, slotIdx)
+    return cur.Reshape(ctx, headVDim, headVDim*numVHeads, 1), nil
+}
+
+// updateDeltaStateForSlot writes a new delta state for a single slot.
+func (c *HybridCache) updateDeltaStateForSlot(ctx ml.Context, layer int, slot int, seqIndex int, newState ml.Tensor) {
+    buf := c.deltaBuffer(ctx, layer)
+    src := newState.Reshape(ctx, c.deltaStateSize, 1)
+    srcF32 := src.Cast(ctx, ml.DTypeF32)
+    slotIdx := ctx.Input().FromInts([]int32{int32(slot)}, 1)
+    ctx.Forward(buf.SetRows(ctx, srcF32, slotIdx))
+    c.captureDeltaCheckpointForSeq(ctx, layer, seqIndex, srcF32)
+}
+
+func (c *HybridCache) captureConvCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
+    if c.checkpointCount == 0 {
+        return
+    }
+    if c.reserveCheckpoints {
+        c.reserveCheckpointConv(layer)
+        return
+    }
+    if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
+        return
+    }
+    pos := c.curCheckpointPos[seqIndex]
+    if pos < 0 {
+        return
+    }
+    slot := c.curSlots[seqIndex]
+    idx := c.checkpointIndexForSlot(slot, pos)
+    if idx < 0 {
+        return
+    }
+    entry := &c.checkpoints[slot].entries[idx]
+    dst := c.ensureCheckpointConv(layer, entry)
+    ctx.Forward(src.Copy(ctx, dst))
+}
+
+func (c *HybridCache) captureDeltaCheckpointForSeq(ctx ml.Context, layer int, seqIndex int, src ml.Tensor) {
+    if c.checkpointCount == 0 {
+        return
+    }
+    if c.reserveCheckpoints {
+        c.reserveCheckpointDelta(layer)
+        return
+    }
+    if seqIndex < 0 || seqIndex >= len(c.curCheckpointPos) {
+        return
+    }
+    pos := c.curCheckpointPos[seqIndex]
+    if pos < 0 {
+        return
+    }
+    slot := c.curSlots[seqIndex]
+    idx := c.checkpointIndexForSlot(slot, pos)
+    if idx < 0 {
+        return
+    }
+    entry := &c.checkpoints[slot].entries[idx]
+    dst := c.ensureCheckpointDelta(layer, entry)
+    ctx.Forward(src.Copy(ctx, dst))
+}
+
+// IsSupportedForBatch returns true if the current batch layout supports grid-style recurrent processing.
 func (c *HybridCache) IsSupportedForBatch() bool {
     return c.curSeqTokens > 0 && len(c.curSeqs) > 0
 }
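The per-slot helpers above read or write a single row of a layer's state buffer, keyed by slot. As a mental model only (plain slices rather than the ml tensor API):

    // Mental model of the per-slot state access (not the ml API).
    type stateBuffer [][]float32 // one flattened state row per slot

    func stateForSlot(b stateBuffer, slot int) []float32 { // ~ buf.Rows(ctx, slotIdx)
        return b[slot]
    }

    func updateStateForSlot(b stateBuffer, slot int, newState []float32) { // ~ buf.SetRows(ctx, srcF32, slotIdx)
        copy(b[slot], newState)
    }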
@@ -48,6 +48,13 @@ type GatedDeltaNet struct {
     Layer int
 }
 
+type stateAccessors struct {
+    convState   func() (ml.Tensor, error)
+    updateConv  func(ml.Tensor)
+    deltaState  func() (ml.Tensor, error)
+    updateDelta func(ml.Tensor)
+}
+
 // createMasks builds the constant mask tensors (called once, reused for all chunks)
 func createMasks(ctx ml.Context) *Masks {
     ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
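stateAccessors decouples the delta-net math from where the recurrent state lives: the grid path binds the whole-batch cache methods, while the mixed path binds the per-slot helpers. Reduced to plain types (hypothetical names, illustration only), the pattern is:

    // Hypothetical reduction of the accessor pattern.
    type accessors struct {
        load  func() ([]float32, error) // ~ convState / deltaState
        store func([]float32)           // ~ updateConv / updateDelta
    }

    func step(a accessors, compute func([]float32) []float32) error {
        state, err := a.load()
        if err != nil {
            return err
        }
        a.store(compute(state)) // same math, regardless of backing storage
        return nil
    }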
@@ -68,7 +75,6 @@ func createMasks(ctx ml.Context) *Masks {
 }
 
 func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
-    layer := gdn.Layer
     nSeqTokens := hiddenStates.Dim(1)
     nSeqs := hiddenStates.Dim(2)
     if cache != nil && cache.IsSupportedForBatch() {
@@ -77,34 +83,140 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
         if seqTokens > 0 && seqs > 0 {
             if nSeqs > 1 {
                 if nSeqTokens != seqTokens || nSeqs != seqs {
-                    return nil, ErrUnsupportedBatchLayout
+                    return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
                 }
             } else {
                 if nSeqTokens != seqTokens*seqs {
-                    return nil, ErrUnsupportedBatchLayout
+                    return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
                 }
                 hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
                 nSeqTokens = seqTokens
                 nSeqs = seqs
             }
         }
+
+        numVHeads := opts.ssmDtRank
+        headVDim := opts.ssmDInner / numVHeads
+        layer := gdn.Layer
+        access := stateAccessors{
+            convState: func() (ml.Tensor, error) {
+                return cache.ConvState(ctx, layer)
+            },
+            updateConv: func(newState ml.Tensor) {
+                cache.UpdateConvState(ctx, layer, newState)
+            },
+            deltaState: func() (ml.Tensor, error) {
+                return cache.DeltaState(ctx, layer, headVDim, numVHeads)
+            },
+            updateDelta: func(newState ml.Tensor) {
+                cache.UpdateDeltaState(ctx, layer, newState)
+            },
+        }
+
+        return gdn.forwardWithAccessors(ctx, hiddenStates, opts, access)
     }
 
+    if cache == nil {
+        return nil, ErrUnsupportedBatchLayout
+    }
+
+    return gdn.forwardMixed(ctx, hiddenStates, cache, opts)
+}
+
+func (gdn *GatedDeltaNet) forwardMixed(ctx ml.Context, hiddenStates ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
+    if hiddenStates.Dim(2) > 0 {
+        hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))
+    }
+
+    if len(cache.curSeqs) == 0 {
+        return hiddenStates, nil
+    }
+
+    // Ensure any shared slots are detached once for this forward pass.
+    cache.ensureWritableOnce(ctx)
+
+    layer := gdn.Layer
+    numVHeads := opts.ssmDtRank
+    headVDim := opts.ssmDInner / numVHeads
+
+    if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+        return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
+    }
+
+    // Precompute projections for the full batch and slice per sequence.
+    mixedBAFull := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
+    qkvMixedFull := gdn.SSMQKV.Forward(ctx, hiddenStates)
+    zFull := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+
+    out := hiddenStates
+    for seqIndex := range cache.curSeqs {
+        idxs := cache.curSeqTokenIdxs[seqIndex]
+        if len(idxs) == 0 {
+            continue
+        }
+        idxTensor := ctx.Input().FromInts(idxs, len(idxs))
+
+        mixedBA := mixedBAFull.Rows(ctx, idxTensor)
+        qkvMixed := qkvMixedFull.Rows(ctx, idxTensor)
+        z := zFull.Rows(ctx, idxTensor)
+
+        slot := cache.curSlots[seqIndex]
+        access := stateAccessors{
+            convState: func() (ml.Tensor, error) {
+                return cache.convStateForSlot(ctx, layer, slot)
+            },
+            updateConv: func(newState ml.Tensor) {
+                cache.updateConvStateForSlot(ctx, layer, slot, seqIndex, newState)
+            },
+            deltaState: func() (ml.Tensor, error) {
+                return cache.deltaStateForSlot(ctx, layer, slot, headVDim, numVHeads)
+            },
+            updateDelta: func(newState ml.Tensor) {
+                cache.updateDeltaStateForSlot(ctx, layer, slot, seqIndex, newState)
+            },
+        }
+
+        seqOut, err := gdn.forwardProjected(ctx, len(idxs), 1, mixedBA, qkvMixed, z, opts, access)
+        if err != nil {
+            return nil, err
+        }
+        out = out.SetRows(ctx, seqOut, idxTensor)
+    }
+
+    return out, nil
+}
+
+func (gdn *GatedDeltaNet) forwardWithAccessors(ctx ml.Context, hiddenStates ml.Tensor, opts *Options, access stateAccessors) (ml.Tensor, error) {
+    nSeqTokens := hiddenStates.Dim(1)
+    nSeqs := hiddenStates.Dim(2)
+
+    mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
+
+    if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+        return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
+    }
+    // Optimized path: pre-split QKV and gate
+    qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates)
+    z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+
+    return gdn.forwardProjected(ctx, nSeqTokens, nSeqs, mixedBA, qkvMixed, z, opts, access)
+}
+
+func (gdn *GatedDeltaNet) forwardProjected(
+    ctx ml.Context,
+    nSeqTokens, nSeqs int,
+    mixedBA, qkvMixed, z ml.Tensor,
+    opts *Options,
+    access stateAccessors,
+) (ml.Tensor, error) {
+    layer := gdn.Layer
+
     headKDim := opts.ssmDState
     numKHeads := opts.ssmNGroup
     numVHeads := opts.ssmDtRank
     headVDim := opts.ssmDInner / numVHeads
     convKernelSize := opts.convKernelSize
 
-    mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
     qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
 
-    if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
-        return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
-    }
-    // Optimized path: pre-split QKV and gate
-    qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
-    z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+    qkvMixed = qkvMixed.Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
 
     baNewDim := 2 * numVHeads / numKHeads
     mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
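forwardMixed is built on a gather/compute/scatter loop: each sequence's token rows are pulled out of the flat batch with Rows, the recurrent step runs on that slice with per-slot state, and the results are written back with SetRows. Over plain slices (illustration only, ad hoc names), the loop has this shape:

    // Illustration of the per-sequence gather/compute/scatter loop.
    func perSequence(batch [][]float32, idxsPerSeq [][]int32, step func([][]float32) [][]float32) {
        for _, idxs := range idxsPerSeq {
            gathered := make([][]float32, 0, len(idxs)) // ~ full.Rows(ctx, idxTensor)
            for _, i := range idxs {
                gathered = append(gathered, batch[i])
            }
            out := step(gathered) // per-sequence recurrent step (nSeqs == 1)
            for j, i := range idxs {
                batch[i] = out[j] // ~ out.SetRows(ctx, seqOut, idxTensor)
            }
        }
    }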
@@ -127,7 +239,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
 
     // Get conv state from cache
-    convStates, err := cache.ConvState(ctx, layer)
+    convStates, err := access.convState()
     if err != nil {
         // Log this - if it happens, short-term context will be lost
         slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)

@@ -142,7 +254,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 
     // Save new conv state (last convKernelSize-1 tokens)
     lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
-    cache.UpdateConvState(ctx, layer, lastConvStates)
+    access.updateConv(lastConvStates)
 
     // Apply SSM convolution (kernel must be F32 for Metal)
     convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)

@@ -162,7 +274,7 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
 
     // Get delta state from cache
-    state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
+    state, err := access.deltaState()
     if err != nil {
         // Log this - if it happens frequently, context will degrade
         slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)

@@ -185,14 +297,19 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
     }
 
     // Choose computation mode based on sequence length
-    var attnOut ml.Tensor
+    var (
+        attnOut  ml.Tensor
+        newState ml.Tensor
+    )
     if nSeqTokens == 1 {
-        attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
+        attnOut, newState = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts)
     } else {
         // Use pre-computed masks from opts (created once in Model.Forward)
-        attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
+        attnOut, newState = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts)
     }
 
+    access.updateDelta(newState)
+
     // Apply gated normalization
     attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
     z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)

@@ -215,9 +332,7 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
     ctx ml.Context,
     q, k, v, gate, beta, state ml.Tensor,
     opts *Options,
-    layer int,
-    cache *HybridCache,
-) ml.Tensor {
+) (ml.Tensor, ml.Tensor) {
     numVHeads := v.Dim(1)
     headVDim := v.Dim(0)
     nSeqs := q.Dim(3)

@@ -273,10 +388,8 @@ func (gdn *GatedDeltaNet) deltaNetAutoregressive(
     coreAttnOut := stateQ.SumRows(ctx)
     coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
 
-    // Update delta state in cache
-    cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
-
-    return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
+    newState := state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
+    return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs), newState
 }
 
 // deltaNetChunked implements chunked computation for prefill.

@@ -286,9 +399,7 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
     q, k, v, gate, beta, state ml.Tensor,
     masks *Masks,
     opts *Options,
-    layer int,
-    cache *HybridCache,
-) ml.Tensor {
+) (ml.Tensor, ml.Tensor) {
     headKDim := q.Dim(0)
     numVHeads := v.Dim(1)
     headVDim := v.Dim(0)

@@ -465,8 +576,6 @@ func (gdn *GatedDeltaNet) deltaNetChunked(
     coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
     }
 
-    // Update delta state in cache
-    cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
-
-    return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
+    newStateFlat := newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs)
+    return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs), newStateFlat
 }
@@ -514,13 +514,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
             continue
         }
 
-        // if past the num predict limit
-        if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-            s.removeSequence(seqIdx, llm.DoneReasonLength)
-            nextBatch.seqs[seqIdx] = nil
-            continue
-        }
-
         if !s.cache.enabled {
             seq.inputs = append(seq.cache.Inputs, seq.inputs...)
             seq.cache.Inputs = []*input.Input{}

@@ -709,7 +702,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
             continue
         }
 
-        seq.numPredicted++
         nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
         seq.inputs = []*input.Input{nextToken}
         nextBatchTokens[i] = nextToken

@@ -745,7 +737,9 @@ func (s *Server) computeBatch(activeBatch batchState) {
             logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
             continue
         }
 
+        seq.lastUpdatedAt = t
+        seq.numPredicted++
         if seq.numPredicted == 1 {
             seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
             seq.startedAt = seq.lastUpdatedAt

@@ -791,6 +785,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
         }
 
         seq.pendingResponses = append(seq.pendingResponses, piece)
+
+        // if past the num predict limit
+        if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+            s.removeSequence(i, llm.DoneReasonLength)
+            continue
+        }
+
         sequence := strings.Join(seq.pendingResponses, "")
 
         if ok, stop := common.FindStop(sequence, seq.stop); ok {
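Moving the num_predict check out of forwardBatch and into computeBatch means a token is counted and its piece emitted before the sequence is stopped, so exactly num_predict pieces reach the client and the response ends with DoneReasonLength; that is the behaviour TestNumPredict asserts by counting logprobs. In outline (illustration with hypothetical hooks, not the runner code):

    // Illustration only: count and deliver first, then enforce the limit.
    func generateUpTo(numPredict int, next func() string, emit func(string), stop func(reason string)) {
        predicted := 0
        for {
            piece := next()
            predicted++
            emit(piece) // the numPredict-th piece is still delivered
            if numPredict > 0 && predicted >= numPredict {
                stop("length") // ~ llm.DoneReasonLength
                return
            }
        }
    }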