mirror of
https://github.com/ollama/ollama.git
synced 2026-01-19 04:51:17 -05:00
Compare commits
4 Commits
parth/decr
...
grace/mist
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5584bf1e19 | ||
|
|
e2f8845f1c | ||
|
|
08d1485846 | ||
|
|
f331801252 |
@@ -216,6 +216,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
|||||||
conv = &deepseekocr{}
|
conv = &deepseekocr{}
|
||||||
case "DeepseekV3ForCausalLM":
|
case "DeepseekV3ForCausalLM":
|
||||||
conv = &deepseek2Model{}
|
conv = &deepseek2Model{}
|
||||||
|
case "MistralForCausalLM":
|
||||||
|
conv = &mistralLarge3Model{}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|||||||
286
convert/convert_mistrallarge3.go
Normal file
286
convert/convert_mistrallarge3.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mistralLarge3Model struct {
|
||||||
|
ModelParameters
|
||||||
|
Dim uint32 `json:"dim"`
|
||||||
|
NumLayers uint32 `json:"n_layers"`
|
||||||
|
HeadDim uint32 `json:"head_dim"`
|
||||||
|
HiddenDim uint32 `json:"hidden_dim"`
|
||||||
|
NumHeads uint32 `json:"n_heads"`
|
||||||
|
NumKVHeads uint32 `json:"n_kv_heads"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
NormEps float32 `json:"norm_eps"`
|
||||||
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
|
TiedEmbeddings bool `json:"tied_embeddings"`
|
||||||
|
MaxPosEmbed uint32 `json:"max_position_embeddings"`
|
||||||
|
MaxSeqLen uint32 `json:"max_seq_len"`
|
||||||
|
|
||||||
|
// LoRA attention parameters (DeepSeek-style)
|
||||||
|
QLoraRank uint32 `json:"q_lora_rank"`
|
||||||
|
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
|
||||||
|
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
|
||||||
|
KVLoraRank uint32 `json:"kv_lora_rank"`
|
||||||
|
VHeadDim uint32 `json:"v_head_dim"`
|
||||||
|
|
||||||
|
// ROPE scaling configurations
|
||||||
|
Llama4Scaling struct {
|
||||||
|
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
|
||||||
|
Beta float32 `json:"beta"`
|
||||||
|
} `json:"llama_4_scaling"`
|
||||||
|
|
||||||
|
Yarn struct {
|
||||||
|
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
|
||||||
|
Factor float32 `json:"factor"`
|
||||||
|
ApplyScale bool `json:"apply_scale"`
|
||||||
|
Beta float32 `json:"beta"`
|
||||||
|
Alpha float32 `json:"alpha"`
|
||||||
|
} `json:"yarn"`
|
||||||
|
|
||||||
|
// MOE configuration
|
||||||
|
MOE struct {
|
||||||
|
ExpertParallel uint32 `json:"expert_parallel"`
|
||||||
|
ExpertModelParallel uint32 `json:"expert_model_parallel"`
|
||||||
|
RouteEveryN uint32 `json:"route_every_n"`
|
||||||
|
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
|
||||||
|
NumExperts uint32 `json:"num_experts"`
|
||||||
|
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
|
||||||
|
NumExpertGroups uint32 `json:"num_expert_groups"`
|
||||||
|
NumExpertGroupsPerTok uint32 `json:"num_expert_groups_per_tok"`
|
||||||
|
RoutedScale float32 `json:"routed_scale"`
|
||||||
|
ExpertHiddenDim uint32 `json:"expert_hidden_dim"`
|
||||||
|
NumSharedExperts uint32 `json:"num_shared_experts"`
|
||||||
|
} `json:"moe"`
|
||||||
|
|
||||||
|
// Vision encoder configuration
|
||||||
|
VisionEncoder struct {
|
||||||
|
ImageTokenID uint32 `json:"image_token_id"`
|
||||||
|
ImageBreakTokenID uint32 `json:"image_break_token_id"`
|
||||||
|
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
MMProjectorID string `json:"mm_projector_id"`
|
||||||
|
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NumChannels uint32 `json:"num_channels"`
|
||||||
|
ImageSize uint32 `json:"image_size"`
|
||||||
|
MaxImageSize uint32 `json:"max_image_size"`
|
||||||
|
PatchSize uint32 `json:"patch_size"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
AddPreMMProjectorLayerNorm bool `json:"add_pre_mm_projector_layer_norm"`
|
||||||
|
AdapterBias bool `json:"adapter_bias"`
|
||||||
|
} `json:"vision_encoder"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistralLarge3Model) KV(t *Tokenizer) ggml.KV {
|
||||||
|
kv := p.ModelParameters.KV(t)
|
||||||
|
kv["general.architecture"] = "deepseek2" // Use deepseek2 architecture for runtime compatibility
|
||||||
|
kv["general.type"] = "model"
|
||||||
|
|
||||||
|
// Basic model parameters (using deepseek2 keys for compatibility)
|
||||||
|
kv["deepseek2.vocab_size"] = p.VocabSize
|
||||||
|
kv["deepseek2.block_count"] = p.NumLayers
|
||||||
|
kv["deepseek2.context_length"] = cmp.Or(p.MaxPosEmbed, p.MaxSeqLen)
|
||||||
|
kv["deepseek2.embedding_length"] = p.Dim
|
||||||
|
kv["deepseek2.feed_forward_length"] = p.HiddenDim
|
||||||
|
|
||||||
|
// Attention configuration
|
||||||
|
kv["deepseek2.attention.head_count"] = p.NumHeads
|
||||||
|
kv["deepseek2.attention.head_count_kv"] = p.NumKVHeads
|
||||||
|
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.NormEps
|
||||||
|
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
|
||||||
|
kv["deepseek2.attention.value_length"] = p.VHeadDim
|
||||||
|
|
||||||
|
// LoRA attention parameters
|
||||||
|
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
|
||||||
|
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
|
||||||
|
|
||||||
|
// ROPE configuration
|
||||||
|
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
|
||||||
|
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
|
||||||
|
|
||||||
|
// ROPE scaling - map to deepseek2 format
|
||||||
|
if p.Yarn.OrigMaxPosEmbed > 0 {
|
||||||
|
kv["deepseek2.rope.scaling.factor"] = p.Yarn.Factor
|
||||||
|
kv["deepseek2.rope.scaling.original_context_length"] = p.Yarn.OrigMaxPosEmbed
|
||||||
|
kv["deepseek2.rope.scaling.type"] = "yarn"
|
||||||
|
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = float32(0.1) // mscale_all_dim * 0.1 as in llama.cpp
|
||||||
|
}
|
||||||
|
|
||||||
|
// MOE configuration
|
||||||
|
if p.MOE.NumExperts > 0 {
|
||||||
|
kv["deepseek2.expert_count"] = p.MOE.NumExperts
|
||||||
|
kv["deepseek2.expert_used_count"] = p.MOE.NumExpertsPerTok
|
||||||
|
kv["deepseek2.expert_shared_count"] = p.MOE.NumSharedExperts
|
||||||
|
kv["deepseek2.expert_feed_forward_length"] = p.MOE.ExpertHiddenDim
|
||||||
|
kv["deepseek2.expert_weights_scale"] = p.MOE.RoutedScale
|
||||||
|
kv["deepseek2.leading_dense_block_count"] = p.MOE.FirstKDenseReplace
|
||||||
|
kv["deepseek2.expert_weights_norm"] = true
|
||||||
|
kv["deepseek2.expert_gating_func"] = uint32(1) // softmax
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vision encoder configuration (if supported by deepseek2 runtime)
|
||||||
|
if p.VisionEncoder.HiddenSize > 0 {
|
||||||
|
kv["deepseek2.vision.block_count"] = p.VisionEncoder.NumHiddenLayers
|
||||||
|
kv["deepseek2.vision.embedding_length"] = p.VisionEncoder.HiddenSize
|
||||||
|
kv["deepseek2.vision.feed_forward_length"] = p.VisionEncoder.IntermediateSize
|
||||||
|
kv["deepseek2.vision.attention.head_count"] = p.VisionEncoder.NumAttentionHeads
|
||||||
|
kv["deepseek2.vision.image_size"] = p.VisionEncoder.ImageSize
|
||||||
|
kv["deepseek2.vision.patch_size"] = p.VisionEncoder.PatchSize
|
||||||
|
kv["deepseek2.vision.num_channels"] = p.VisionEncoder.NumChannels
|
||||||
|
|
||||||
|
// Multimodal configuration
|
||||||
|
kv["deepseek2.image_token_id"] = p.VisionEncoder.ImageTokenID
|
||||||
|
kv["deepseek2.image_break_token_id"] = p.VisionEncoder.ImageBreakTokenID
|
||||||
|
kv["deepseek2.image_end_token_id"] = p.VisionEncoder.ImageEndTokenID
|
||||||
|
kv["deepseek2.spatial_merge_size"] = p.VisionEncoder.SpatialMergeSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set tokenizer type - use tekken preprocessing (now supported!)
|
||||||
|
kv["tokenizer.ggml.pre"] = "tekken"
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistralLarge3Model) specialTokenTypes() []string {
|
||||||
|
return []string{
|
||||||
|
"bos", "eos", "unk", "sep", "pad", "cls", "mask",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistralLarge3Model) Replacements() []string {
|
||||||
|
return []string{
|
||||||
|
"lm_head", "output",
|
||||||
|
"tok_embeddings", "token_embd", // Mistral Large uses tok_embeddings instead of model.embed_tokens
|
||||||
|
"norm", "output_norm",
|
||||||
|
"language_model.", "",
|
||||||
|
"layers", "blk", // Mistral 3 Large uses "layers" instead of "model.layers"
|
||||||
|
"attention_norm", "attn_norm",
|
||||||
|
|
||||||
|
// LoRA attention mappings (Mistral 3 Large style)
|
||||||
|
"attention.wkv_a_with_mqa", "attn_kv_a_mqa",
|
||||||
|
"attention.kv_a_norm", "attn_kv_a_norm",
|
||||||
|
"attention.wkv_b", "attn_kv_b",
|
||||||
|
"attention.wq_a", "attn_q_a",
|
||||||
|
"attention.q_a_norm", "attn_q_a_norm",
|
||||||
|
"attention.wq_b", "attn_q_b",
|
||||||
|
"attention.wo", "attn_output",
|
||||||
|
|
||||||
|
"ffn_norm", "ffn_norm", // Keep ffn_norm as is
|
||||||
|
|
||||||
|
// MOE mappings for Mistral 3 Large
|
||||||
|
"shared_experts.w2", "ffn_down_shexp",
|
||||||
|
"shared_experts.w1", "ffn_gate_shexp",
|
||||||
|
"shared_experts.w3", "ffn_up_shexp",
|
||||||
|
"experts.*.w1", "ffn_gate_exps", // Will be merged in Tensors()
|
||||||
|
"experts.*.w2", "ffn_down_exps", // Will be merged in Tensors()
|
||||||
|
"experts.*.w3", "ffn_up_exps", // Will be merged in Tensors()
|
||||||
|
"gate", "ffn_gate_inp",
|
||||||
|
|
||||||
|
// Standard feed forward mappings (for non-MOE layers)
|
||||||
|
"feed_forward.w1", "ffn_gate",
|
||||||
|
"feed_forward.w2", "ffn_down",
|
||||||
|
"feed_forward.w3", "ffn_up",
|
||||||
|
|
||||||
|
// Mistral-specific tensor renaming
|
||||||
|
".qscale_act", ".input_scale",
|
||||||
|
".qscale_weight", ".weight_scale",
|
||||||
|
|
||||||
|
// Vision encoder mappings - do we even need this?
|
||||||
|
"vision_tower", "v",
|
||||||
|
"ln_pre", "encoder_norm",
|
||||||
|
"attention.q_proj", "attn_q",
|
||||||
|
"attention.k_proj", "attn_k",
|
||||||
|
"attention.v_proj", "attn_v",
|
||||||
|
"attention.o_proj", "attn_output",
|
||||||
|
"attention_norm", "attn_norm",
|
||||||
|
"feed_forward.gate_proj", "ffn_gate",
|
||||||
|
"feed_forward.down_proj", "ffn_down",
|
||||||
|
"feed_forward.up_proj", "ffn_up",
|
||||||
|
|
||||||
|
"multi_modal_projector", "mm",
|
||||||
|
"patch_merger.merging_layer", "mm.patch_merger",
|
||||||
|
"pre_mm_projector_norm", "mm.pre_norm",
|
||||||
|
"vision_language_adapter.w_in", "mm.w_in",
|
||||||
|
"vision_language_adapter.w_out", "mm.w_out",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mistralLarge3Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
|
||||||
|
// Create merges for MOE expert tensors
|
||||||
|
if p.MOE.NumExperts > 0 {
|
||||||
|
merges := make([]merge, p.NumLayers*3)
|
||||||
|
for i := range p.NumLayers {
|
||||||
|
merges[i*3+0] = merge{
|
||||||
|
fmt.Sprintf("blk.%d.experts.*.w1.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||||
|
}
|
||||||
|
merges[i*3+1] = merge{
|
||||||
|
fmt.Sprintf("blk.%d.experts.*.w3.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||||
|
}
|
||||||
|
merges[i*3+2] = merge{
|
||||||
|
fmt.Sprintf("blk.%d.experts.*.w2.weight", i),
|
||||||
|
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out, s = mergeTensors(s, merges...)
|
||||||
|
}
|
||||||
|
|
||||||
|
skipLayer := func(n string, minValue uint32) bool {
|
||||||
|
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||||
|
matches := re.FindStringSubmatch(n)
|
||||||
|
if matches == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
blkNum, err := strconv.Atoi(matches[1])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint32(blkNum) >= minValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to check if tensor should be skipped (vision components)
|
||||||
|
skipVisionTensor := func(name string) bool {
|
||||||
|
return strings.HasPrefix(name, "vision_") ||
|
||||||
|
strings.HasPrefix(name, "patch_merger.") ||
|
||||||
|
strings.Contains(name, "mm_projector")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range s {
|
||||||
|
name := t.Name()
|
||||||
|
|
||||||
|
// Skip vision tensors (handled separately or not needed)
|
||||||
|
if skipVisionTensor(name) {
|
||||||
|
slog.Debug("skipping vision tensor", "name", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip any additional layers beyond expected count
|
||||||
|
if skipLayer(name, p.NumLayers) {
|
||||||
|
slog.Debug("skipping extra layer", "name", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: name,
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: t.Shape(),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
t.Pre = "deepseek-coder"
|
t.Pre = "deepseek-coder"
|
||||||
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
|
||||||
t.Pre = "qwen2"
|
t.Pre = "qwen2"
|
||||||
|
case "1d64a9a8eaf9f1bd80331984d81fdd514e7feafe8df83a525dd31472f275699a":
|
||||||
|
t.Pre = "tekken"
|
||||||
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
|
||||||
// noop, empty pretokenizer
|
// noop, empty pretokenizer
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package deepseek2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
@@ -39,6 +40,10 @@ type Options struct {
|
|||||||
ropeBase,
|
ropeBase,
|
||||||
ropeScale float32
|
ropeScale float32
|
||||||
kqScale float64
|
kqScale float64
|
||||||
|
|
||||||
|
attentionTemperatureScale float32
|
||||||
|
attentionTemperatureLength int
|
||||||
|
attentionTemperatureFloorScale int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||||
@@ -66,7 +71,7 @@ type Attention struct {
|
|||||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
seqLength := hiddenStates.Dim(1)
|
seqLength := hiddenStates.Dim(1)
|
||||||
|
|
||||||
var query ml.Tensor
|
var query ml.Tensor
|
||||||
@@ -104,6 +109,11 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||||||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||||
|
|
||||||
|
if attentionScales != nil {
|
||||||
|
query = query.Mul(ctx, attentionScales)
|
||||||
|
}
|
||||||
|
|
||||||
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||||
} else { // v3.1
|
} else { // v3.1
|
||||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||||
@@ -115,6 +125,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||||||
key := kRot.Concat(ctx, kPass, 0)
|
key := kRot.Concat(ctx, kPass, 0)
|
||||||
value := kPass
|
value := kPass
|
||||||
|
|
||||||
|
if attentionScales != nil {
|
||||||
|
query = query.Mul(ctx, attentionScales)
|
||||||
|
}
|
||||||
|
|
||||||
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,10 +215,10 @@ type Layer struct {
|
|||||||
MLP MLP
|
MLP MLP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
residual := hiddenStates
|
residual := hiddenStates
|
||||||
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
|
||||||
|
|
||||||
if outputs != nil {
|
if outputs != nil {
|
||||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||||
@@ -234,7 +248,11 @@ type Model struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
layers := make([]Layer, c.Uint("block_count"))
|
// layers := make([]Layer, c.Uint("block_count"))
|
||||||
|
// fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", c.Uint("block_count"))
|
||||||
|
|
||||||
|
layers := make([]Layer, 4)
|
||||||
|
fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", 4)
|
||||||
|
|
||||||
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
||||||
for i := range layers {
|
for i := range layers {
|
||||||
@@ -261,6 +279,10 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
`[一-龥-ゟ゠-ヿ]+`,
|
`[一-龥-ゟ゠-ヿ]+`,
|
||||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||||
}
|
}
|
||||||
|
case "tekken":
|
||||||
|
pre = []string{
|
||||||
|
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
}
|
||||||
case "deepseek-llm":
|
case "deepseek-llm":
|
||||||
// TODO: these models haven't been vetted so skip for now
|
// TODO: these models haven't been vetted so skip for now
|
||||||
// pre = []string{
|
// pre = []string{
|
||||||
@@ -276,13 +298,20 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
return nil, model.ErrUnsupportedTokenizer
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DEBUG: Check tokenizer vocabulary loading
|
||||||
|
tokens := c.Strings("tokenizer.ggml.tokens")
|
||||||
|
tokenTypes := c.Ints("tokenizer.ggml.token_type")
|
||||||
|
merges := c.Strings("tokenizer.ggml.merges")
|
||||||
|
|
||||||
|
// Debug output removed for performance
|
||||||
|
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: tokens,
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: tokenTypes,
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: merges,
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: false, // c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOS: append(
|
EOS: append(
|
||||||
@@ -316,6 +345,11 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
routedScalingFactor: c.Float("expert_weights_scale"),
|
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
|
|
||||||
|
// TODO: double check these values
|
||||||
|
attentionTemperatureScale: c.Float("attention.temperature_scale", 1.0),
|
||||||
|
attentionTemperatureLength: int(c.Uint("attention.temperature_length")),
|
||||||
|
attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
|
||||||
|
|
||||||
kqScale: kqScale,
|
kqScale: kqScale,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -331,8 +365,28 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
|||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
// DEBUG: Check TokenEmbedding initialization
|
||||||
|
if m.TokenEmbedding == nil {
|
||||||
|
panic("DEBUG: m.TokenEmbedding is nil - 'token_embd' tensor not found in GGUF")
|
||||||
|
}
|
||||||
|
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
// Temperature tuning - used by mistral-large
|
||||||
|
var attentionScales ml.Tensor
|
||||||
|
if m.attentionTemperatureScale != 0.0 {
|
||||||
|
nTokens := len(batch.Positions)
|
||||||
|
scales := make([]float32, nTokens)
|
||||||
|
|
||||||
|
for i, pos := range batch.Positions {
|
||||||
|
posFloat := float64(pos)
|
||||||
|
scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
|
||||||
|
scales[i] = float32(scaleValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
|
||||||
|
}
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
m.Cache.SetLayer(i)
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
@@ -341,7 +395,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
outputs = batch.Outputs
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
|||||||
Reference in New Issue
Block a user