mirror of https://github.com/ollama/ollama.git
synced 2026-01-29 01:33:06 -05:00

Compare commits: brucemacd/... to pdevine/bf

7 Commits
c75b428249
021dcf089d
bf24498b1e
95e271d98f
364629b8d6
108fe02165
4561fff36e
@@ -182,10 +182,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	var conv ModelConverter
 	switch p.Architectures[0] {
-	case "LlamaForCausalLM":
+	case "LlamaForCausalLM", "MistralForCausalLM":
 		conv = &llamaModel{}
-	case "Mistral3ForConditionalGeneration":
-		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":

@@ -248,10 +246,5 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		return err
 	}

-	// iterate through all ts and print the name
-	for _, t := range ts {
-		fmt.Print(t.Name(), "\n")
-	}
-
 	return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
 }
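The switch above dispatches on the first entry of config.json's "architectures" array; after this change, text-only Mistral checkpoints reuse the existing llama converter instead of a dedicated one. A minimal sketch of the input that drives the dispatch (illustrative JSON, not from the repository):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// A HuggingFace config.json carries the architecture name that
	// ConvertModel switches on.
	raw := []byte(`{"architectures": ["MistralForCausalLM"]}`)
	var p struct {
		Architectures []string `json:"architectures"`
	}
	if err := json.Unmarshal(raw, &p); err != nil {
		panic(err)
	}
	fmt.Println(p.Architectures[0]) // now routed to &llamaModel{}
}
```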
@@ -1,223 +0,0 @@
-package convert
-
-import (
-	"cmp"
-	"fmt"
-	"strings"
-
-	"github.com/pdevine/tensor"
-	"github.com/pdevine/tensor/native"
-
-	"github.com/ollama/ollama/fs/ggml"
-)
-
-type mistral3Model struct {
-	ModelParameters
-	ImageTokenIndex    uint32 `json:"image_token_index"`
-	SpatialMergeSize   uint32 `json:"spatial_merge_size"`
-	VisionFeatureLayer int32  `json:"vision_feature_layer"`
-	TextModel struct {
-		NumHiddenLayers       uint32  `json:"num_hidden_layers"`
-		MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
-		HiddenSize            uint32  `json:"hidden_size"`
-		IntermediateSize      uint32  `json:"intermediate_size"`
-		NumAttentionHeads     uint32  `json:"num_attention_heads"`
-		NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
-		RopeTheta             float32 `json:"rope_theta"`
-		RMSNormEPS            float32 `json:"rms_norm_eps"`
-		HeadDim               uint32  `json:"head_dim"`
-		SlidingWindow         *uint32 `json:"sliding_window"`
-		HiddenAct             string  `json:"hidden_act"`
-		VocabSize             uint32  `json:"vocab_size"`
-	} `json:"text_config"`
-	VisionModel struct {
-		NumAttentionHeads uint32  `json:"num_attention_heads"`
-		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
-		HiddenSize        uint32  `json:"hidden_size"`
-		IntermediateSize  uint32  `json:"intermediate_size"`
-		ImageSize         uint32  `json:"image_size"`
-		NumChannels       uint32  `json:"num_channels"`
-		PatchSize         uint32  `json:"patch_size"`
-		HeadDim           uint32  `json:"head_dim"`
-		HiddenAct         string  `json:"hidden_act"`
-		RopeTheta         float32 `json:"rope_theta"`
-	} `json:"vision_config"`
-	MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
-	ProjectorHiddenAct      string `json:"projector_hidden_act"`
-}
-
-func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
-	kv := p.ModelParameters.KV(t)
-	kv["general.architecture"] = "mistral3"
-	kv["mistral3.vocab_size"] = p.TextModel.VocabSize
-
-	// Text configuration
-	kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
-	kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
-	kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
-	kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
-	kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
-	kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
-	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
-	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
-	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
-	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
-	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
-
-	// Vision configuration
-	kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
-	kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
-	kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
-	kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
-	kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
-	kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
-	kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
-	kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
-	// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
-	kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
-
-	// Multimodal configuration
-	kv["mistral3.image_token_index"] = p.ImageTokenIndex
-	kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
-
-	kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
-
-	if p.ProjectorHiddenAct != "" {
-		kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
-	}
-
-	return kv
-}
-
-func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
-	var out []ggml.Tensor
-
-	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
-			strings.HasSuffix(t.Name(), "attn_k.weight") {
-			t.SetRepacker(p.repack)
-		}
-
-		// Skip certain vision model tensors that might need special handling
-		if strings.HasPrefix(t.Name(), "patch_merger.") || strings.HasPrefix(t.Name(), "pre_mm_projector_output_norm.") {
-			continue
-		}
-
-		out = append(out, ggml.Tensor{
-			Name:     t.Name(),
-			Kind:     t.Kind(),
-			Shape:    t.Shape(),
-			WriterTo: t,
-		})
-	}
-
-	return out
-}
-
-func (p *mistral3Model) Replacements() []string {
-	return []string{
-		// Text model replacements
-		"model.layers", "blk",
-		"input_layernorm", "attn_norm",
-		"post_attention_layernorm", "ffn_norm",
-		"lm_head", "output",
-		"model.embed_tokens.weight", "token_embd.weight",
-		"model.norm.weight", "output_norm.weight",
-		"self_attn.q_proj", "attn_q",
-		"self_attn.k_proj", "attn_k",
-		"self_attn.v_proj", "attn_v",
-		"self_attn.o_proj", "attn_output",
-		"mlp.down_proj", "ffn_down",
-		"mlp.gate_proj", "ffn_gate",
-		"mlp.up_proj", "ffn_up",
-
-		// Language model replacements
-		"language_model.model.embed_tokens", "token_embd",
-		"language_model.model.layers", "blk",
-		"language_model.model.layers.*.input_layernorm", "attn_norm",
-		"language_model.model.layers.*.self_attn.q_proj", "attn_q",
-		"language_model.model.layers.*.self_attn.k_proj", "attn_k",
-		"language_model.model.layers.*.self_attn.v_proj", "attn_v",
-		"language_model.model.layers.*.self_attn.o_proj", "attn_output",
-		"language_model.model.layers.*.mlp.gate_proj", "ffn_gate",
-		"language_model.model.layers.*.mlp.down_proj", "ffn_down",
-		"language_model.model.layers.*.mlp.up_proj", "ffn_up",
-		"language_model.model.layers.*.post_attention_layernorm", "ffn_norm",
-		"language_model.lm_head", "output",
-		"language_model.model.norm", "output_norm",
-
-		// Vision model replacements - map to shorter prefixes
-		"vision_tower", "v",
-		"multi_modal_projector", "mm",
-
-		// Vision transformer blocks - these should be updated accordingly
-		"vision_tower.transformer.layers", "v.blk",
-		"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
-		"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
-		"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
-		"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
-		"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
-		"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
-		"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
-		"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
-		"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
-		"vision_tower.ln_pre", "v.encoder_norm",
-		"vision_tower.patch_conv", "v.patch_conv",
-		"vision_tower.embeddings", "v.embeddings",
-
-		// Alternative vision model paths
-		"vision_model.vision_model.embeddings", "v.embeddings",
-		"vision_model.vision_model", "v",
-		"vision_model.layers", "v.blk",
-
-		// Multimodal projector components
-		"multi_modal_projector.patch_merger", "mm.patch_merger",
-		"multi_modal_projector.norm", "mm.norm",
-		"multi_modal_projector.linear", "mm.projection",
-	}
-}
-
-func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
-	var dims []int
-	for _, dim := range shape {
-		dims = append(dims, int(dim))
-	}
-
-	var heads uint32
-	if strings.HasSuffix(name, "attn_q.weight") {
-		heads = p.TextModel.NumAttentionHeads
-	} else if strings.HasSuffix(name, "attn_k.weight") {
-		heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
-	} else {
-		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
-	}
-
-	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
-	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
-		return nil, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return nil, err
-	}
-
-	if err := n.Reshape(dims...); err != nil {
-		return nil, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return nil, err
-	}
-
-	ts, err := native.SelectF32(n, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	var f32s []float32
-	for _, t := range ts {
-		f32s = append(f32s, t...)
-	}
-
-	return f32s, nil
-}
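The deleted repack method above permutes Q/K weight rows between RoPE layouts via Reshape, T(0, 2, 1, 3), and Reshape. A minimal plain-slice sketch of that row permutation, as I read it (it ignores the trailing Transpose/SelectF32 pass, which only materializes a contiguous buffer; the shapes and layout names are my interpretation, not from the source):

```go
// repackRows permutes a row-major [rows, cols] weight matrix so each head's
// rows move from two contiguous halves (index h, half, i) to interleaved
// pairs (index h, i, 2). Assumes rows = heads * headDim with headDim even.
func repackRows(data []float32, rows, cols, heads int) []float32 {
	headDim := rows / heads
	half := headDim / 2
	out := make([]float32, len(data))
	for h := 0; h < heads; h++ {
		for i := 0; i < half; i++ {
			for j := 0; j < 2; j++ {
				src := (h*headDim + j*half + i) * cols // split-half layout
				dst := (h*headDim + 2*i + j) * cols    // interleaved layout
				copy(out[dst:dst+cols], data[src:src+cols])
			}
		}
	}
	return out
}
```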
@@ -62,7 +62,10 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 		Pattern string
 		Func    func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
 	}{
-		{"*.safetensors", parseSafetensors},
+		{"model-*-of-*.safetensors", parseSafetensors},
+		{"model.safetensors", parseSafetensors},
+		{"adapters.safetensors", parseSafetensors},
+		{"adapter_model.safetensors", parseSafetensors},
 		{"pytorch_model-*-of-*.bin", parseTorch},
 		{"pytorch_model.bin", parseTorch},
 		{"consolidated.*.pth", parseTorch},
@@ -11,9 +11,10 @@ import (
 	"slices"
 	"strings"

-	"github.com/d4l3k/go-bfloat16"
 	"github.com/x448/float16"
 	"golang.org/x/exp/maps"
+
+	"github.com/ollama/ollama/types/bfloat16"
 )

 type safetensorMetadata struct {
go.mod (1 line changed)

@@ -16,7 +16,6 @@ require (

 require (
 	github.com/agnivade/levenshtein v1.1.1
-	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/dlclark/regexp2 v1.11.4
 	github.com/emirpasic/gods/v2 v2.0.0-alpha
 	github.com/google/go-cmp v0.6.0
go.sum (2 lines changed)

@@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
 github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
-github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
-github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -312,17 +312,19 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			return fmt.Errorf("unassigned tensor: %s", t.Name)
 		}

-		bts := make([]byte, t.Size())
-		n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-		if err != nil {
-			return err
-		}
-
-		if n != len(bts) {
-			return errors.New("short read")
-		}
-
-		C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+		bts := C.malloc(C.size_t(t.Size()))
+		if bts == nil {
+			return errors.New("failed to allocate tensor buffer")
+		}
+		defer C.free(bts)
+
+		buf := unsafe.Slice((*byte)(bts), t.Size())
+		n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
+		if err != nil || n != len(buf) {
+			return errors.New("read failed")
+		}
+
+		C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
 		return nil
 	})
 }
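The new code reads tensor bytes straight into C-allocated memory, so the pointer handed to ggml never refers to a Go-managed allocation that the runtime could move. A stripped-down sketch of the pattern (hypothetical helper, not part of the ollama codebase):

```go
package cbuf

/*
#include <stdlib.h>
*/
import "C"

import (
	"errors"
	"unsafe"
)

// cBuffer allocates size bytes with C.malloc and returns a Go slice view of
// them plus a release function. The memory is outside the Go heap, so it is
// stable for the buffer's lifetime when passed to C.
func cBuffer(size int) ([]byte, func(), error) {
	p := C.malloc(C.size_t(size))
	if p == nil {
		return nil, nil, errors.New("malloc failed")
	}
	// View the C memory as an ordinary []byte for io.ReadFull and friends.
	return unsafe.Slice((*byte)(p), size), func() { C.free(p) }, nil
}
```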
@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
 			C.int(len(schedBackends)),
 			C.size_t(maxGraphNodes),
-			true,
+			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
 		),
 		input:  deviceBufferTypes[input.d],
 		output: deviceBufferTypes[output.d],
@@ -10,7 +10,7 @@ import (
 	"github.com/ollama/ollama/model/input"
 )

-type TextConfig struct {
+type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
 	eps, ropeScale                   float32

@@ -27,7 +27,7 @@ type TextModel struct {
 	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
 	Output     *nn.Linear  `gguf:"output,alt:token_embd"`

-	*TextConfig
+	*TextOptions
 }

 const (

@@ -55,7 +55,7 @@ func newTextModel(c ml.Config) *TextModel {
 			},
 		),
 		Layers: make([]TextLayer, numBlocks),
-		TextConfig: &TextConfig{
+		TextOptions: &TextOptions{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),

@@ -84,7 +84,7 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }

-func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
+func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	ropeType := uint32(2)

@@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextConfig.ropeLocalBase
+	ropeBase := m.TextOptions.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextConfig.ropeGlobalBase
+		ropeBase = m.TextOptions.ropeGlobalBase
 	}

-	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
 }

 type TextMLP struct {

@@ -134,7 +134,7 @@ type TextMLP struct {
 	Gate *nn.Linear `gguf:"ffn_gate"`
 }

-func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
+func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
 	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -148,7 +148,7 @@ type TextLayer struct {
 	PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
 }

-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState

 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)

@@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,

 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

 	// set image embeddings
 	var except []int

@@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}

-		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
 	}

 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
 	outputSize := image.Point{p.imageSize, p.imageSize}
 	newImage := imageproc.Composite(img)
-	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)

 	data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
 	return data, nil
@@ -13,9 +13,9 @@ import (
 )

 type Options struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
-	eps, ropeBase, ropeScale                  float32
-	ropeDim                                   uint32
+	hiddenSize, numHeads, numKVHeads int
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
 }

 type Model struct {

@@ -37,8 +37,6 @@ func New(c ml.Config) (model.Model, error) {

 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
-			// TODO: need to set this in the conversion for mistral:
-			// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),

@@ -55,7 +53,6 @@ func New(c ml.Config) (model.Model, error) {
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			headDim:    int(c.Uint("attention.key_length")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),

@@ -78,36 +75,24 @@ type SelfAttention struct {

 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
+	headDim := opts.hiddenSize / opts.numHeads
 	ropeType := uint32(0)
-	// Get head dimension - use explicit value if available, otherwise calculate
-	headDim := opts.headDim
-	if headDim == 0 {
-		headDim = opts.hiddenSize / opts.numHeads
-	}

-	// Query projection and reshape
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

-	// Key projection and reshape
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

-	// Value projection and reshape
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-	// Attention computation
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
 	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

-	// Reshape attention output for final projection
-	outputDim := headDim * opts.numHeads
-	kqv = kqv.Reshape(ctx, outputDim, batchSize)

-	// Apply output projection
 	return sa.Output.Forward(ctx, kqv)
 }
@@ -1,139 +0,0 @@
-package mistral3
-
-import (
-	"bytes"
-	"image"
-	"slices"
-
-	"github.com/ollama/ollama/kvcache"
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/model/input"
-)
-
-type Model struct {
-	model.Base
-	*TextModel
-	*VisionModel         `gguf:"v,vision"`
-	*MultiModalProjector `gguf:"mm"`
-
-	ImageProcessor
-}
-
-// Implement MultimodalProcessor interface
-var _ model.MultimodalProcessor = (*Model)(nil)
-
-func New(c ml.Config) (model.Model, error) {
-	textModel, err := NewTextModel(c)
-	if err != nil {
-		return nil, err
-	}
-
-	m := &Model{
-		TextModel:           textModel,
-		VisionModel:         newVisionModel(c),
-		ImageProcessor:      newImageProcessor(c),
-		MultiModalProjector: newMultiModalProjector(c),
-	}
-
-	m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
-
-	return m, nil
-}
-
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
-	if len(m.VisionModel.Layers) == 0 {
-		return nil, model.ErrNoVisionModel
-	}
-
-	// Decode image
-	image, _, err := image.Decode(bytes.NewReader(multimodalData))
-	if err != nil {
-		return nil, err
-	}
-
-	// Process image
-	f32s, err := m.ImageProcessor.ProcessImage(image)
-	if err != nil {
-		return nil, err
-	}
-
-	// Create tensor from image data
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
-		m.ImageProcessor.imageSize,
-		m.ImageProcessor.imageSize,
-		m.ImageProcessor.numChannels,
-	)
-	if err != nil {
-		return nil, err
-	}
-
-	// Forward pass through vision model
-	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
-
-	// Project to text embedding space
-	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
-
-	return visionOutputs, nil
-}
-
-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	var result []input.Input
-
-	for _, inp := range inputs {
-		if inp.Multimodal == nil {
-			result = append(result, inp)
-		} else {
-			inputMultimodal := inp.Multimodal.(ml.Tensor)
-
-			// Add special image tokens - using the imageTokenIndex from config
-			result = append(result,
-				input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)},            // Image token
-				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data
-			)
-
-			// Add image token placeholders
-			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
-		}
-	}
-
-	return result, nil
-}
-
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	// Handle multimodal inputs
-	// var except []int
-	// hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
-
-	// for _, image := range opts.Multimodal {
-	// 	visionOutputs := image.Multimodal.(ml.Tensor)
-
-	// 	// Copy vision outputs into the hidden state
-	// 	ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
-
-	// 	for i := range visionOutputs.Dim(1) {
-	// 		except = append(except, image.Index+i)
-	// 	}
-	// }
-
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
-}
-
-func init() {
-	model.Register("mistral3", New)
-}
@@ -1,171 +0,0 @@
-package mistral3
-
-import (
-	"fmt"
-	"math"
-	"strings"
-
-	"github.com/ollama/ollama/kvcache"
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/model/input"
-)
-
-type TextOptions struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
-	eps, ropeBase, ropeScale                  float32
-	ropeDim                                   uint32
-}
-
-type TextModel struct {
-	model.Base
-	model.BytePairEncoding
-
-	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
-	Layers         []Layer       `gguf:"blk"`
-	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
-	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
-
-	*TextOptions
-}
-
-type SelfAttention struct {
-	Query       *nn.Linear `gguf:"attn_q"`
-	Key         *nn.Linear `gguf:"attn_k"`
-	Value       *nn.Linear `gguf:"attn_v"`
-	Output      *nn.Linear `gguf:"attn_output"`
-	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
-}
-
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
-	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(0)
-	// Get head dimension - use explicit value if available, otherwise calculate
-	headDim := opts.headDim
-	if headDim == 0 {
-		headDim = opts.hiddenSize / opts.numHeads
-	}
-
-	// Query projection and reshape
-	q := sa.Query.Forward(ctx, hiddenState)
-	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// Key projection and reshape
-	k := sa.Key.Forward(ctx, hiddenState)
-	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// Value projection and reshape
-	v := sa.Value.Forward(ctx, hiddenState)
-	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-
-	// Attention computation
-	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
-
-	// Reshape attention output for final projection
-	outputDim := headDim * opts.numHeads
-	kqv = kqv.Reshape(ctx, outputDim, batchSize)
-
-	// Apply output projection
-	return sa.Output.Forward(ctx, kqv)
-}
-
-func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
-}
-
-type MLP struct {
-	Up   *nn.Linear `gguf:"ffn_up"`
-	Down *nn.Linear `gguf:"ffn_down"`
-	Gate *nn.Linear `gguf:"ffn_gate"`
-}
-
-func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
-}
-
-type Layer struct {
-	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
-	SelfAttention *SelfAttention
-	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP           *MLP
-}
-
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
-	residual := hiddenState
-
-	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
-
-	// In the final layer (outputs != nil), optimize by pruning to just the token positions
-	// we need logits for.
-	if outputs != nil {
-		hiddenState = hiddenState.Rows(ctx, outputs)
-		residual = residual.Rows(ctx, outputs)
-	}
-
-	hiddenState = hiddenState.Add(ctx, residual)
-	residual = hiddenState
-
-	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
-	return hiddenState.Add(ctx, residual)
-}
-
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
-	// Process text inputs
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-
-	// Process through text transformer layers
-	for i, layer := range m.Layers {
-		cache.SetLayer(i)
-
-		var lastLayerOutputs ml.Tensor
-		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
-		}
-
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
-	}
-
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return m.Output.Forward(ctx, hiddenState)
-}
-
-func NewTextModel(c ml.Config) (*TextModel, error) {
-	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
-		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
-	}
-
-	textModel := &TextModel{
-		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Types:  c.Uints("tokenizer.ggml.token_type"),
-				Merges: c.Strings("tokenizer.ggml.merges"),
-				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-			},
-		),
-		Layers: make([]Layer, c.Uint("block_count")),
-		TextOptions: &TextOptions{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			headDim:    int(c.Uint("attention.key_length")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
-		},
-	}
-
-	return textModel, nil
-}
@@ -1,143 +0,0 @@
-package mistral3
-
-import (
-	"math"
-
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/ml/nn"
-)
-
-var batchSize int = 1
-
-type VisionSelfAttention struct {
-	Query       *nn.Linear `gguf:"attn_q"`
-	Key         *nn.Linear `gguf:"attn_k"`
-	Value       *nn.Linear `gguf:"attn_v"`
-	Output      *nn.Linear `gguf:"attn_output"`
-	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
-}
-
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	headDim := opts.headDim
-
-	query := sa.Query.Forward(ctx, hiddenState)
-	key := sa.Key.Forward(ctx, hiddenState)
-	value := sa.Value.Forward(ctx, hiddenState)
-
-	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
-
-	ropeType := uint32(0)
-	query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
-	key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
-
-	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
-	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
-
-	return sa.Output.Forward(ctx, attention)
-}
-
-type VisionMLP struct {
-	Gate *nn.Linear `gguf:"ffn_gate"`
-	Up   *nn.Linear `gguf:"ffn_up"`
-	Down *nn.Linear `gguf:"ffn_down"`
-}
-
-func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
-}
-
-type VisionEncoderLayer struct {
-	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
-	SelfAttention *VisionSelfAttention
-
-	FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP     *VisionMLP  `gguf:"mlp"`
-}
-
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	residual := hiddenState
-
-	// self attention
-	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
-	hiddenState = hiddenState.Add(ctx, residual)
-	residual = hiddenState
-
-	// feed forward
-	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
-	return hiddenState.Add(ctx, residual)
-}
-
-type VisionModelOptions struct {
-	hiddenSize       int
-	numHeads         int
-	headDim          int
-	intermediateSize int
-	imageSize        int
-	patchSize        int
-	numChannels      int
-	eps              float32
-	ropeBase         float32
-	ropeScale        float32
-}
-
-type VisionModel struct {
-	PatchEmbedding *nn.Conv2D           `gguf:"patch_conv"`
-	EncoderNorm    *nn.LayerNorm        `gguf:"encoder_norm"`
-	Layers         []VisionEncoderLayer `gguf:"blk"`
-
-	*VisionModelOptions
-}
-
-func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
-	numPatchesH := m.imageSize / m.patchSize
-	numPatchesW := m.imageSize / m.patchSize
-	numPatches := numPatchesH * numPatchesW
-
-	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
-	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
-	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-
-	// Create position IDs
-	positions := make([]int32, numPatches)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		panic(err)
-	}
-
-	// Apply encoder normalization
-	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps)
-
-	// Process through transformer layers
-	for _, layer := range m.Layers {
-		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
-	}
-
-	return hiddenState
-}
-
-func newVisionModel(c ml.Config) *VisionModel {
-	return &VisionModel{
-		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
-		VisionModelOptions: &VisionModelOptions{
-			hiddenSize:       int(c.Uint("vision.embedding_length", 1024)),
-			numHeads:         int(c.Uint("vision.attention.head_count", 16)),
-			headDim:          int(c.Uint("vision.attention.key_length", 64)),
-			intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
-			imageSize:        int(c.Uint("vision.image_size", 1540)),
-			patchSize:        int(c.Uint("vision.patch_size", 14)),
-			numChannels:      int(c.Uint("vision.num_channels", 3)),
-			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-05),
-			ropeBase:         c.Float("vision.rope.freq_base", 10000.0),
-			ropeScale:        c.Float("vision.rope.freq_scale", 1.0),
-		},
-	}
-}
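For reference, the patch bookkeeping in the deleted Forward above, using the defaults wired into newVisionModel (image_size 1540, patch_size 14); a quick check of the arithmetic:

```go
package main

import "fmt"

func main() {
	imageSize, patchSize := 1540, 14
	perSide := imageSize / patchSize // 110 patches per side
	numPatches := perSide * perSide  // 12100 position IDs fed to RoPE
	fmt.Println(perSide, numPatches)
}
```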
@@ -1,38 +0,0 @@
-package mistral3
-
-import (
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/ml/nn"
-)
-
-type MultiModalProjector struct {
-	Norm       *nn.RMSNorm `gguf:"norm"`
-	Projection *nn.Linear  `gguf:"projection"`
-
-	spatialMergeSize int
-	imageTokenIndex  int
-	hasBias          bool
-}
-
-func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
-	// Apply normalization
-	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
-
-	// If the spatial merge size is > 1, average pool the patches
-	if p.spatialMergeSize > 1 {
-		// Implementation depends on how the model handles spatial merging
-		// For simplicity, we'll use a spatial pooling approach
-		visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0)
-	}
-
-	// Project to text embedding dimension
-	return p.Projection.Forward(ctx, visionOutputs)
-}
-
-func newMultiModalProjector(c ml.Config) *MultiModalProjector {
-	return &MultiModalProjector{
-		spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
-		imageTokenIndex:  int(c.Uint("image_token_index", 10)),
-		hasBias:          c.Bool("mm.projector_bias", false),
-	}
-}
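The deleted projector's AvgPool2D merge shrinks the patch grid by spatial_merge_size in each dimension before projecting into the text embedding space. Continuing the numbers from the vision defaults, and assuming the pooling acts over the restored 2D patch grid (illustrative arithmetic, not a trace of the kernel):

```go
package main

import "fmt"

func main() {
	numPatches := 110 * 110 // vision tower output columns
	merge := 2              // spatial_merge_size default
	merged := numPatches / (merge * merge)
	fmt.Println(merged) // 3025 embedding columns reach the text model
}
```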
@@ -4,6 +4,5 @@ import (
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/llama"
-	_ "github.com/ollama/ollama/model/models/mistral3"
 	_ "github.com/ollama/ollama/model/models/mllama"
 )
@@ -1,4 +1,4 @@
-package mistral3
+package pixtral

 import (
 	"fmt"

@@ -8,7 +8,6 @@ import (
 	"io"
 	"math"

-	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model/imageproc"
 )

@@ -28,8 +27,8 @@ func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.

 	if ratio > 1.0 {
 		newSize = image.Point{
-			int(math.Floor(float64(b.Max.X) / ratio)),
-			int(math.Floor(float64(b.Max.Y) / ratio)),
+			int(math.Ceil(float64(b.Max.X) / ratio)),
+			int(math.Ceil(float64(b.Max.Y) / ratio)),
 		}
 	}

@@ -67,30 +66,3 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
 	opts := map[string]any{}
 	return data, opts, nil
 }
-
-type ImageProcessor struct {
-	imageSize   int
-	patchSize   int
-	numChannels int
-	longestEdge int
-}
-
-func newImageProcessor(c ml.Config) ImageProcessor {
-	return ImageProcessor{
-		imageSize:   int(c.Uint("vision.image_size", 1540)),
-		patchSize:   int(c.Uint("vision.patch_size", 14)),
-		numChannels: int(c.Uint("vision.num_channels", 3)),
-		longestEdge: int(c.Uint("vision.longest_edge", 1024)),
-	}
-}
-
-func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
-	outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
-
-	newImage := imageproc.Composite(img)
-	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
-
-	data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
-
-	return data, nil
-}
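The Floor-to-Ceil swap above changes how the longest edge is clamped: rounding up keeps the scaled size from undershooting, so no pixel is left outside the final patch grid. A self-contained sketch of the computation (an assumed helper mirroring the hunk, not the exact ollama function):

```go
package main

import (
	"fmt"
	"math"
)

func resizeToLongestEdge(w, h, longestEdge int) (int, int) {
	ratio := math.Max(float64(w), float64(h)) / float64(longestEdge)
	if ratio > 1.0 {
		// Ceil (the new behavior) rounds up; Floor could drop up to a
		// pixel per axis relative to the true scaled extent.
		return int(math.Ceil(float64(w) / ratio)), int(math.Ceil(float64(h) / ratio))
	}
	return w, h
}

func main() {
	fmt.Println(resizeToLongestEdge(1030, 771, 1024)) // 1024 767 (Floor gave 766)
}
```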
@@ -1,4 +1,4 @@
-package mistral3
+package pixtral

 import (
 	"bytes"
@@ -263,10 +263,6 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 			continue
 		}

-		if id := bpe.vocab.Encode(pair.value); id < 0 {
-			continue
-		}
-
 		merges[pair.a].runes = append(left.runes, right.runes...)
 		merges[pair.b].runes = nil
@@ -209,322 +209,6 @@ func TestLlama(t *testing.T) {
 	})
 }

-// tekken loads the Tekken tokenizer for testing
-func tekken(t testing.TB) TextProcessor {
-	t.Helper()
-
-	// Load tokenizer config from mistral-small
-	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
-	configFile, err := os.Open(tokenizerConfigPath)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer configFile.Close()
-
-	var config struct {
-		AddBosToken bool   `json:"add_bos_token"`
-		AddEosToken bool   `json:"add_eos_token"`
-		BosToken    string `json:"bos_token"`
-		EosToken    string `json:"eos_token"`
-	}
-	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
-		t.Fatal(err)
-	}
-
-	// Load tokenizer.json which contains the vocabulary and other settings
-	tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
-	tokenizerFile, err := os.Open(tokenizerJsonPath)
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer tokenizerFile.Close()
-
-	var tokenizerData struct {
-		Model struct {
-			Type   string           `json:"type"`
-			Vocab  map[string]int32 `json:"vocab"`
-			Merges []string         `json:"merges"`
-		} `json:"model"`
-		AddedTokens []struct {
-			Id      int32  `json:"id"`
-			Content string `json:"content"`
-			Special bool   `json:"special"`
-		} `json:"added_tokens"`
-		PreTokenizer struct {
-			Type          string `json:"type"`
-			Pretokenizers []struct {
-				Type    string `json:"type"`
-				Pattern struct {
-					String string `json:"String"`
-				} `json:"pattern"`
-				Behavior string `json:"behavior"`
-			} `json:"pretokenizers"`
-		} `json:"pre_tokenizer"`
-	}
-	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
-		t.Fatal(err)
-	}
-
-	// Extract the pattern from pre_tokenizer if available
-	var pattern string
-	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
-		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
-	}
-
-	// Combine regular vocab and added tokens
-	vocab := tokenizerData.Model.Vocab
-
-	// Add special tokens from added_tokens
-	for _, token := range tokenizerData.AddedTokens {
-		vocab[token.Content] = token.Id
-	}
-
-	// Create vocabulary arrays
-	maxId := int32(-1)
-	for _, id := range vocab {
-		if id > maxId {
-			maxId = id
-		}
-	}
-
-	vocabSize := int(maxId + 1)
-	types := make([]uint32, vocabSize)
-	tokens := make([]string, vocabSize)
-	scores := make([]float32, vocabSize)
-
-	for token, id := range vocab {
-		tokens[id] = token
-		types[id] = TOKEN_TYPE_NORMAL
-
-		// Assign appropriate token types for special tokens
-		if token == "<s>" {
-			types[id] = TOKEN_TYPE_CONTROL
-		} else if token == "</s>" {
-			types[id] = TOKEN_TYPE_CONTROL
-		} else if token == "[INST]" || token == "[/INST]" {
-			types[id] = TOKEN_TYPE_CONTROL
-		}
-	}
-
-	// In Tekken, we don't need to load merges separately as they're part of the model
-	var merges []string
-
-	// Create vocabulary object
-	vocabObj := &Vocabulary{
-		Values: tokens,
-		Types:  types,
-		Scores: scores,
-		Merges: merges,
-		BOS:    vocab[config.BosToken],
-		EOS:    vocab[config.EosToken],
-		AddBOS: config.AddBosToken,
-		AddEOS: config.AddEosToken,
-	}
-
-	// Use pattern from tokenizer.json if available
-	if pattern != "" {
-		// Ensure pattern has proper escaping for Go regexp
-		pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
-		return NewBytePairEncoding(pattern, vocabObj)
-	}
-
-	// Fallback pattern if not found
-	return NewBytePairEncoding(
-		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
-		vocabObj,
-	)
-}
-
-func TestTekken(t *testing.T) {
-	// Skip if the test data isn't available
-	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
-		t.Skip("Mistral-small test data not available")
-	}
-
-	tokenizer := tekken(t)
-
-	t.Run("whitespace_handling", func(t *testing.T) {
-		t.Parallel()
-
-		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
-		cases := []struct {
-			input    string
-			expected string
-		}{
-			{" hello", " hello"},
-			{"hello ", "hello "},
-			{"hello world", "hello world"},
-			{" hello world ", " hello world "},
-		}
-
-		for _, tc := range cases {
-			ids, err := tokenizer.Encode(tc.input, false)
-			if err != nil {
-				t.Errorf("Failed to encode %q: %v", tc.input, err)
-				continue
-			}
-
-			decoded, err := tokenizer.Decode(ids)
-			if err != nil {
-				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
-				continue
-			}
-
-			if decoded != tc.expected {
-				t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
-			}
-		}
-	})
-
-	t.Run("chat_templates", func(t *testing.T) {
-		t.Parallel()
-
-		// Test the Tekken chat template format which doesn't have spaces after special tokens
-		templates := []struct {
-			input       string
-			expectSpace bool // whether we expect a space after special tokens
-		}{
-			{"<s>[INST]user message[/INST]", false},
-			{"<s>[INST] user message[/INST]", true},
-			{"<s>[INST]user message [/INST]", true},
-		}
-
-		for _, tc := range templates {
-			ids, err := tokenizer.Encode(tc.input, false)
-			if err != nil {
-				t.Errorf("Failed to encode %q: %v", tc.input, err)
-				continue
-			}
-
-			decoded, err := tokenizer.Decode(ids)
-			if err != nil {
-				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
-				continue
-			}
-
-			// Check if there's a space after special tokens
-			hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
-
-			if hasSpaceAfterINST != tc.expectSpace {
-				t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
-					hasSpaceAfterINST, tc.expectSpace, tc.input)
-			}
-		}
-	})
-
-	t.Run("special_tokens", func(t *testing.T) {
-		t.Parallel()
-
-		// Test how Tekken handles special tokens
-		cases := []struct {
-			input    string
-			expected []string // We'll check if these tokens are in the decoded output
-		}{
-			{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
-			{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
-			{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
-		}
-
-		for _, tc := range cases {
-			ids, err := tokenizer.Encode(tc.input, false)
-			if err != nil {
-				t.Errorf("Failed to encode %q: %v", tc.input, err)
-				continue
-			}
-
-			decoded, err := tokenizer.Decode(ids)
-			if err != nil {
-				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
-				continue
-			}
-
-			for _, expected := range tc.expected {
-				if !strings.Contains(decoded, expected) {
-					t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
-				}
-			}
-		}
-	})
-
-	t.Run("vocabulary_coverage", func(t *testing.T) {
-		t.Parallel()
-
-		// Tekken has a larger vocabulary, so test coverage of various token types
-		samples := []string{
-			"Hello world!",
-			"This is a test of the Tekken tokenizer.",
-			"It has a considerably larger vocabulary size.",
-			"Special characters: !@#$%^&*()",
-			"Numbers: 1234567890",
-			"Multiple languages: こんにちは 你好 안녕하세요",
-			"Code snippets: def function(): return True",
-		}
-
-		for _, sample := range samples {
-			ids, err := tokenizer.Encode(sample, false)
-			if err != nil {
-				t.Errorf("Failed to encode %q: %v", sample, err)
-				continue
-			}
-
-			decoded, err := tokenizer.Decode(ids)
-			if err != nil {
-				t.Errorf("Failed to decode tokens for %q: %v", sample, err)
-				continue
-			}
-
-			if decoded != sample {
-				t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
-			}
-		}
-	})
-
-	t.Run("splitting_behavior", func(t *testing.T) {
-		t.Parallel()
-
-		// Test the splitting behavior which might differ from SentencePiece
-		cases := map[string][]string{
-			"Hello World!": {"Hello", " World", "!"},
-			"user message": {"user", " message"},
-			"[INST]hello":  {"[INST]", "hello"},
-			"hello[/INST]": {"hello", "[/INST]"},
-		}
-
-		for s, want := range cases {
-			got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
-			if diff := cmp.Diff(want, got); diff != "" {
-				t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
-			}
-		}
-	})
-
-	t.Run("full_chat_sequence", func(t *testing.T) {
-		t.Parallel()
-
-		// Test a complete chat sequence with Tekken's format
-		chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
-
-		ids, err := tokenizer.Encode(chatSequence, false)
-		if err != nil {
-			t.Fatalf("Failed to encode chat sequence: %v", err)
-		}
-
-		decoded, err := tokenizer.Decode(ids)
-		if err != nil {
-			t.Fatalf("Failed to decode chat sequence tokens: %v", err)
-		}
-
-		// In Tekken, the whitespace shouldn't be added after special tokens
-		if strings.Contains(decoded, "[INST] ") {
-			t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
-		}
-
-		if strings.Contains(decoded, "[/INST] ") {
-			t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
-		}
-	})
-}
-
 func BenchmarkBytePairEncoding(b *testing.B) {
 	tokenizer := llama(b)
 	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
model/testdata/mistral-small/tokenizer.json (vendored, 1217945 lines): file diff suppressed because it is too large
model/testdata/mistral-small/tokenizer_config.json (vendored, 9020 lines): file diff suppressed because it is too large
@@ -211,10 +211,16 @@ func filesForModel(path string) ([]string, error) {
 	}

 	var files []string
-	if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
+	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
+		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
+	} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
+		// covers adapters.safetensors
+		files = append(files, st...)
+	} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
+		// covers adapter_model.safetensors
+		files = append(files, st...)
 	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
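A quick check of what the narrowed glob matches, using standard library pattern matching on illustrative filenames:

```go
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	for _, name := range []string{
		"model.safetensors",                // matched
		"model-00001-of-00002.safetensors", // matched
		"adapters.safetensors",             // not matched; handled by its own branch
		"consolidated.00.safetensors",      // not matched
	} {
		ok, _ := filepath.Match("model*.safetensors", name)
		fmt.Println(name, ok)
	}
}
```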
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error

@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}

-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
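With the cachePrompt parameter gone, call sites shrink to the following shape (a sketch of the updated usage inside the runner, assuming the caller already tokenized the prompt into inputs):

```go
// Before: slot, remaining, err := s.cache.LoadCacheSlot(inputs, true)
// After:
slot, remaining, err := s.cache.LoadCacheSlot(inputs)
if err != nil {
	return err
}
_ = slot      // returned slot is already marked InUse
_ = remaining // prompt suffix that still needs prefill
```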
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
+			}
+		})
+	}
+}
@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
        params.numKeep = int32(len(inputs))
    }

    // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
    // otherwise we might truncate or split the batch against the model's wishes

    // Ensure that at least 1 input can be discarded during shift
    params.numKeep = min(params.numKeep, s.cache.numCtx-1)
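
As a worked example of the clamp: with numCtx = 2048 and a prompt that fills the whole window, numKeep would otherwise be 2048 and a shift could discard nothing; min(numKeep, numCtx-1) caps it at 2047, so at least one token is always evictable.
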
@@ -179,10 +182,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
        return nil, nil, err
    }

    for _, t := range tokens {
        decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
        fmt.Println("token", t, "decoded", decoded)
    }
    for _, t := range tokens {
        inputs = append(inputs, input.Input{Token: t})
    }
@@ -370,17 +369,6 @@ func (s *Server) processBatch() error {
    batchSize := s.batchSize

    for j, inp := range seq.inputs {
        if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
            if len(seq.pendingInputs) == 0 {
                err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
                if err != nil {
                    return err
                }
            } else {
                break
            }
        }

        // If we are required to put following inputs into a single batch then extend the
        // batch size. Since we are only extending the size the minimum amount possible, this
        // will cause a break if we have pending inputs.
@@ -393,6 +381,20 @@ func (s *Server) processBatch() error {
            break
        }

        // If the sum of our working set (already processed tokens, tokens we added to this
        // batch, required following tokens) exceeds the context size, then trigger a shift
        // now so we don't have to do one later when we can't break the batch.
        if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
            if len(seq.pendingInputs) != 0 {
                break
            }

            err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
            if err != nil {
                return err
            }
        }

        options.Inputs = append(options.Inputs, inp.Token)
        if inp.Multimodal != nil {
            options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
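
Illustrative numbers for the new check: with numCtx = 512, 500 tokens already in seq.cache.Inputs, no pending inputs, and minBatch = 16, the working set is 500 + 0 + 16 = 516 > 512, so the slot is shifted before the batch is started; if pending inputs existed, the loop would instead break and finish the current batch first.
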
@@ -594,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
    found := false
    for i, sq := range s.seqs {
        if sq == nil {
            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
            seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
            if err != nil {
                s.mu.Unlock()
                http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
    // topK also sorts the tokens in descending order of logits
    tokens = topK(tokens, s.topK)

    tokens = temperature(tokens, s.temperature)
    tokens = softmax(tokens)
    // scale and normalize the tokens in place
    temperature(tokens, s.temperature)
    softmax(tokens)

    tokens = topP(tokens, s.topP)
    tokens = minP(tokens, s.minP)
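
The net effect on the sampling pipeline: topK, topP, and minP still return a (possibly shorter) slice, while temperature and softmax now mutate their argument in place and return nothing. A condensed sketch of the resulting order of operations (a paraphrase of this diff, not the full sample method):

    tokens = topK(tokens, s.topK)      // sorts descending, may truncate
    temperature(tokens, s.temperature) // in place: divide each logit by temp
    softmax(tokens)                    // in place: logits -> probabilities
    tokens = topP(tokens, s.topP)      // cut at cumulative probability p
    tokens = minP(tokens, s.minP)      // cut below p * max probability
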
@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
}

// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
func temperature(ts []token, temp float32) {
    // Ensure temperature clipping near 0 to avoid numerical instability
    temp = max(temp, 1e-7)
    for i := range ts {
        ts[i].value = ts[i].value / temp
    }
    return ts
}

// softmax applies normalization to the logits
func softmax(ts []token) []token {
func softmax(ts []token) {
    // Find max logit for numerical stability
    maxLogit := float32(math.Inf(-1))
    for _, t := range ts {
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
    for i := range ts {
        ts[i].value /= sum
    }

    return ts
}
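
For reference, the normalization softmax computes in place is p_i = exp(l_i - max_j l_j) / sum_k exp(l_k - max_j l_j); subtracting the maximum logit first (the maxLogit scan above) keeps exp from overflowing without changing the result.
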
// topK limits the number of tokens considered to the k highest logits
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
}

// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
    if p == 1.0 {
        return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
    for i, t := range ts {
        sum += t.value
        if sum > float32(p) {
            ts = ts[:i+1]
            return ts
            return ts[:i+1]
        }
    }

    return ts
}
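
Worked example for topP: with sorted probabilities {0.5, 0.3, 0.15, 0.05} and p = 0.9, the running sums are 0.5, 0.8, 0.95; the sum first exceeds 0.9 at the third token, so ts[:i+1] keeps exactly the first three.
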
// minP limits tokens to those with cumulative probability p
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
    if p == 1.0 {
        return ts
    }
    maxProb := ts[0].value

    maxProb := float32(math.Inf(-1))
    for _, token := range ts {
        if token.value > maxProb {
            maxProb = token.value
    threshold := maxProb * p

    for i, t := range ts {
        if t.value < threshold {
            return ts[:i]
        }
    }

    threshold := maxProb * float32(p)

    // Filter tokens in-place
    validTokens := ts[:0]
    for i, token := range ts {
        if token.value >= threshold {
            validTokens = append(validTokens, ts[i])
        }
    }

    ts = validTokens
    return ts
}
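
Worked example for the new minP: with sorted probabilities {0.4, 0.3, 0.1, 0.05} and p = 0.2, the threshold is 0.4 * 0.2 = 0.08; the scan stops at 0.05, the first value below the threshold, and returns the first three tokens. Reading maxProb from ts[0] is only valid because the slice is already sorted in descending order, which is exactly what the new doc comment requires.
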
@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {

func TestTemperature(t *testing.T) {
    input := []float32{1.0, 4.0, -2.0, 0.0}
    got := temperature(toTokens(input), 0.5)
    tokens := toTokens(input)
    temperature(tokens, 0.5)
    want := []float32{2.0, 8.0, -4.0, 0.0}
    compareLogits(t, "temperature(0.5)", want, got)
    compareLogits(t, "temperature(0.5)", want, tokens)

    got = temperature(toTokens(input), 1.0)
    input = []float32{1.0, 4.0, -2.0, 0.0}
    tokens = toTokens(input)
    temperature(tokens, 1.0)
    want = []float32{1.0, 4.0, -2.0, 0.0}
    compareLogits(t, "temperature(1)", want, got)
    compareLogits(t, "temperature(1)", want, tokens)

    got = temperature(toTokens(input), 0.0)
    input = []float32{1.0, 4.0, -2.0, 0.0}
    tokens = toTokens(input)
    temperature(tokens, 0.0)
    want = []float32{1e7, 4e7, -2e7, 0.0}
    compareLogits(t, "temperature(0)", want, got)
    compareLogits(t, "temperature(0)", want, tokens)
}
func TestSoftmax(t *testing.T) {
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got := softmax(toTokens(tt.input))
            tokens := toTokens(tt.input)
            softmax(tokens)

            if tt.expected != nil {
                compareLogits(t, tt.name, tt.expected, got)
                compareLogits(t, tt.name, tt.expected, tokens)
                return
            }

            // Check probabilities sum to 1
            var sum float32
            for _, token := range got {
            for _, token := range tokens {
                sum += token.value
                if token.value < 0 || token.value > 1 {
                    t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {

func TestTopK(t *testing.T) {
    input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}

    // Test k=5
    got := topK(toTokens(input), 5)
    if len(got) != 5 {
        t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
    tokens := toTokens(input)
    tokens = topK(tokens, 5)
    if len(tokens) != 5 {
        t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
    }
    // Should keep the 5 highest values in descending order
    want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
    compareLogits(t, "topK(3)", want, got)
    compareLogits(t, "topK(5)", want, tokens)

    got = topK(toTokens(input), 20)
    if len(got) != len(input) {
        t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
    tokens = toTokens(input)
    tokens = topK(tokens, 20)
    if len(tokens) != len(input) {
        t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
    }

    // Test k=-1
    input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
    want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
    got = topK(toTokens(input), -1)
    if len(got) != len(input) {
        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
    tokens = toTokens(input)
    tokens = topK(tokens, -1)
    if len(tokens) != len(input) {
        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
    }
    compareLogits(t, "topK(-1)", want, got)
    compareLogits(t, "topK(-1)", want, tokens)

    // Test k=0
    input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
    want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
    got = topK(toTokens(input), 0)
    if len(got) != len(input) {
        t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
    tokens = toTokens(input)
    tokens = topK(tokens, 0)
    if len(tokens) != len(input) {
        t.Errorf("topK(0): wrong length: want %d, got %d", len(input), len(tokens))
    }
    compareLogits(t, "topK(0)", want, tokens)

    input = []float32{-1e7, -2e7, -3e7, -4e7}
    tokens = toTokens(input)
    tokens = topK(tokens, 1)
    if len(tokens) < 1 {
        t.Error("topK should keep at least one token")
    }
    compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
    tokens := toTokens(input)

    // First apply temperature and softmax to get probabilities
    tokens = softmax(tokens)
    softmax(tokens)
    tokens = topK(tokens, 20)

    // Then apply topP
    got := topP(tokens, 0.95)
    tokens = topP(tokens, 0.95)

    // Should keep tokens until cumsum > 0.95
    if len(got) > 3 {
        t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
        t.Logf("got: %v", got)
    if len(tokens) > 3 {
        t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
        t.Logf("got: %v", tokens)
    }

    // Test edge case - ensure at least one token remains
    input = []float32{-1e6, -1e6, -1e6} // three equal logits, no dominant token
    tokens = toTokens(input)
    softmax(tokens)
    tokens = topP(tokens, 0.0) // Very small p
    if len(tokens) < 1 {
        t.Error("topP should keep at least one token")
    }
}
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
    tokens := toTokens(input)

    // First apply temperature and softmax
    tokens = softmax(tokens)
    tokens = topK(tokens, 20)
    softmax(tokens)

    // Then apply minP
    got := minP(tokens, 0.2)
    tokens = minP(tokens, 1.0)

    if len(tokens) != len(input) {
        t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(input))
    }

    // Test with normal p value
    tokens = toTokens(input) // Reset tokens
    tokens = topK(tokens, 20)
    softmax(tokens)
    tokens = minP(tokens, 0.2)

    // Should keep tokens with prob >= 0.2 * max_prob
    if len(got) > 3 {
        t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
    if len(tokens) > 3 {
        t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
        t.Logf("got: %v", tokens)
    }

    // Test with zero p value
    tokens = toTokens(input) // Reset tokens
    tokens = topK(tokens, 20)
    softmax(tokens)
    tokens = minP(tokens, 0.0)

    // minP(0.0) disables filtering, so all tokens should remain
    if len(tokens) != len(input) {
        t.Errorf("minP(0.0): should keep all tokens, got %d, want %d", len(tokens), len(input))
        t.Logf("got: %v", tokens)
    }

    input = []float32{1e-10, 1e-10, 1e-10}
    tokens = toTokens(input)
    softmax(tokens)
    tokens = minP(tokens, 1.0)
    if len(tokens) < 1 {
        t.Error("minP should keep at least one token even with extreme probabilities")
    }
}
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
            topK(tokensCopy, 10)
            tokens = topK(tokensCopy, 10)
        }
    })

@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
            topP(tokensCopy, 0.9)
            tokens = topP(tokensCopy, 0.9)
        }
    })

@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
            minP(tokensCopy, 0.2)
            tokens = minP(tokensCopy, 0.2)
        }
    })

@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
        b.ResetTimer()
        for b.Loop() {
            copy(tokensCopy, tokens)
            topK(tokensCopy, 200000)
            tokens = topK(tokensCopy, 200000)
        }
    })
}
21  types/bfloat16/LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Tristan Rice

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

57  types/bfloat16/bfloat16.go  Normal file
@@ -0,0 +1,57 @@
// Vendored code from https://github.com/d4l3k/go-bfloat16
// unsafe pointer usage replaced by the "math" package
package bfloat16

import "math"

type BF16 uint16

func FromBytes(buf []byte) BF16 {
    return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
}

func ToBytes(b BF16) []byte {
    return []byte{byte(b & 0xFF), byte(b >> 8)}
}

func Decode(buf []byte) []BF16 {
    var out []BF16
    for i := 0; i < len(buf); i += 2 {
        out = append(out, FromBytes(buf[i:]))
    }
    return out
}

func Encode(f []BF16) []byte {
    var out []byte
    for _, a := range f {
        out = append(out, ToBytes(a)...)
    }
    return out
}

func DecodeFloat32(buf []byte) []float32 {
    var out []float32
    for i := 0; i < len(buf); i += 2 {
        out = append(out, ToFloat32(FromBytes(buf[i:])))
    }
    return out
}

func EncodeFloat32(f []float32) []byte {
    var out []byte
    for _, a := range f {
        out = append(out, ToBytes(FromFloat32(a))...)
    }
    return out
}

func ToFloat32(b BF16) float32 {
    u32 := uint32(b) << 16
    return math.Float32frombits(u32)
}

func FromFloat32(f float32) BF16 {
    u32 := math.Float32bits(f)
    return BF16(u32 >> 16)
}
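
A bfloat16 value is just the top 16 bits of an IEEE-754 float32 (same sign bit and 8-bit exponent, mantissa cut from 23 to 7 bits), so both conversions are pure bit shifts with no rounding. A small usage sketch, assuming the import path implied by the file location:

    package main

    import (
        "fmt"

        "github.com/ollama/ollama/types/bfloat16"
    )

    func main() {
        b := bfloat16.FromFloat32(3.14159) // truncates the low 16 mantissa bits
        fmt.Println(bfloat16.ToFloat32(b)) // prints 3.140625

        raw := bfloat16.EncodeFloat32([]float32{1.0, 2.0})
        fmt.Println(bfloat16.DecodeFloat32(raw)) // [1 2]: exact, both are powers of two
    }
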
53  types/bfloat16/bfloat16_test.go  Normal file
@@ -0,0 +1,53 @@
package bfloat16

import (
    "crypto/rand"
    "reflect"
    "testing"
)

func randomBytes(n int) []byte {
    out := make([]byte, n)
    if _, err := rand.Read(out); err != nil {
        panic(err)
    }
    return out
}

func TestEncodeDecode(t *testing.T) {
    b := randomBytes(1024)
    bf16 := Decode(b)
    out := Encode(bf16)
    if !reflect.DeepEqual(b, out) {
        t.Fatalf("%+v != %+v", b, out)
    }
}

func TestEncodeDecodeFloat32(t *testing.T) {
    b := randomBytes(1024)
    bf16 := DecodeFloat32(b)
    out := EncodeFloat32(bf16)
    if !reflect.DeepEqual(b, out) {
        t.Fatalf("%+v != %+v", b, out)
    }
}

func TestBasicFloat32(t *testing.T) {
    var in float32 = 1.0
    out := ToFloat32(FromFloat32(in))
    if !reflect.DeepEqual(in, out) {
        t.Fatalf("%+v != %+v", in, out)
    }
}

func TestComplexFloat32(t *testing.T) {
    var in float32 = 123456789123456789.123456789
    var want float32 = 123286039799267328.0
    out := ToFloat32(FromFloat32(in))
    if in == out {
        t.Fatalf("no loss of precision")
    }
    if out != want {
        t.Fatalf("%.16f != %.16f", want, out)
    }
}