Mirror of https://github.com/ollama/ollama.git (synced 2026-01-07 23:20:02 -05:00)

Compare commits (18 commits): jmorganca/... → v0.11.0
| SHA1 |
|---|
| d552068413 |
| 8306248cfb |
| f5fd7cc19c |
| 6a68a17c2c |
| aa43da4b5a |
| 0ac1c0d3c2 |
| e6f39bce8f |
| 0263ad9b6d |
| 4fb47ed368 |
| 9194874df1 |
| 9d1de41b49 |
| 9679520e87 |
| c8ac4cc546 |
| 6ca094a3ed |
| 26ade3a349 |
| 9950f6ec24 |
| f1c73840d5 |
| 4a8fc3f945 |
api/types.go (119 lines changed)
```diff
@@ -85,10 +85,11 @@ type GenerateRequest struct {
 	Options map[string]any `json:"options"`
 
 	// Think controls whether thinking/reasoning models will think before
-	// responding. Needs to be a pointer so we can distinguish between false
+	// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
+	// for supported models. Needs to be a pointer so we can distinguish between false
 	// (request that thinking _not_ be used) and unset (use the old behavior
 	// before this option was introduced)
-	Think *bool `json:"think,omitempty"`
+	Think *ThinkValue `json:"think,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -116,8 +117,9 @@ type ChatRequest struct {
 	Options map[string]any `json:"options"`
 
 	// Think controls whether thinking/reasoning models will think before
-	// responding
-	Think *bool `json:"think,omitempty"`
+	// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
+	// for supported models.
+	Think *ThinkValue `json:"think,omitempty"`
 }
 
 type Tools []Tool
@@ -508,6 +510,8 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	Metrics
+
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
 // ModelDetails provides details about a model.
@@ -677,6 +681,113 @@ func DefaultOptions() Options {
 	}
 }
 
+// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
+type ThinkValue struct {
+	// Value can be a bool or string
+	Value interface{}
+}
+
+// IsValid checks if the ThinkValue is valid
+func (t *ThinkValue) IsValid() bool {
+	if t == nil || t.Value == nil {
+		return true // nil is valid (means not set)
+	}
+
+	switch v := t.Value.(type) {
+	case bool:
+		return true
+	case string:
+		return v == "high" || v == "medium" || v == "low"
+	default:
+		return false
+	}
+}
+
+// IsBool returns true if the value is a boolean
+func (t *ThinkValue) IsBool() bool {
+	if t == nil || t.Value == nil {
+		return false
+	}
+	_, ok := t.Value.(bool)
+	return ok
+}
+
+// IsString returns true if the value is a string
+func (t *ThinkValue) IsString() bool {
+	if t == nil || t.Value == nil {
+		return false
+	}
+	_, ok := t.Value.(string)
+	return ok
+}
+
+// AsBool returns the value as a bool (true if enabled in any way)
+func (t *ThinkValue) AsBool() bool {
+	if t == nil || t.Value == nil {
+		return false
+	}
+
+	switch v := t.Value.(type) {
+	case bool:
+		return v
+	case string:
+		// Any string value ("high", "medium", "low") means thinking is enabled
+		return v == "high" || v == "medium" || v == "low"
+	default:
+		return false
+	}
+}
+
+// AsString returns the value as a string
+func (t *ThinkValue) AsString() string {
+	if t == nil || t.Value == nil {
+		return ""
+	}
+
+	switch v := t.Value.(type) {
+	case string:
+		return v
+	case bool:
+		if v {
+			return "medium" // Default level when just true
+		}
+		return ""
+	default:
+		return ""
+	}
+}
+
+// UnmarshalJSON implements json.Unmarshaler
+func (t *ThinkValue) UnmarshalJSON(data []byte) error {
+	// Try to unmarshal as bool first
+	var b bool
+	if err := json.Unmarshal(data, &b); err == nil {
+		t.Value = b
+		return nil
+	}
+
+	// Try to unmarshal as string
+	var s string
+	if err := json.Unmarshal(data, &s); err == nil {
+		// Validate string values
+		if s != "high" && s != "medium" && s != "low" {
+			return fmt.Errorf("invalid think value: %q (must be \"high\", \"medium\", \"low\", true, or false)", s)
+		}
+		t.Value = s
+		return nil
+	}
+
+	return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
+}
+
+// MarshalJSON implements json.Marshaler
+func (t *ThinkValue) MarshalJSON() ([]byte, error) {
+	if t == nil || t.Value == nil {
+		return []byte("null"), nil
+	}
+	return json.Marshal(t.Value)
+}
+
 type Duration struct {
 	time.Duration
 }
```
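For reference, here is a minimal sketch of how the new field behaves on the wire. It uses only the `api` types added above; the `main` wrapper and sample payloads are illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	for _, payload := range []string{
		`{"think": true}`,     // boolean form
		`{"think": "high"}`,   // effort-level form
		`{"think": "banana"}`, // rejected by ThinkValue.UnmarshalJSON
	} {
		var req api.GenerateRequest
		if err := json.Unmarshal([]byte(payload), &req); err != nil {
			fmt.Println("error:", err)
			continue
		}
		// AsBool reports whether thinking is enabled at all; AsString
		// normalizes a bare `true` to the default "medium" level.
		fmt.Println(req.Think.AsBool(), req.Think.AsString())
	}
}
```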
api/types_test.go

```diff
@@ -374,24 +374,21 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
 }
 
 func TestThinking_UnmarshalJSON(t *testing.T) {
-	trueVal := true
-	falseVal := false
-
 	tests := []struct {
 		name             string
 		input            string
-		expectedThinking *bool
+		expectedThinking *ThinkValue
 		expectedError    bool
 	}{
 		{
 			name:             "true",
 			input:            `{ "think": true }`,
-			expectedThinking: &trueVal,
+			expectedThinking: &ThinkValue{Value: true},
 		},
 		{
 			name:             "false",
 			input:            `{ "think": false }`,
-			expectedThinking: &falseVal,
+			expectedThinking: &ThinkValue{Value: false},
 		},
 		{
 			name: "unset",
@@ -399,8 +396,23 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
 			expectedThinking: nil,
 		},
 		{
-			name:             "invalid",
-			input:            `{ "think": "true" }`,
+			name:             "string_high",
+			input:            `{ "think": "high" }`,
+			expectedThinking: &ThinkValue{Value: "high"},
+		},
+		{
+			name:             "string_medium",
+			input:            `{ "think": "medium" }`,
+			expectedThinking: &ThinkValue{Value: "medium"},
+		},
+		{
+			name:             "string_low",
+			input:            `{ "think": "low" }`,
+			expectedThinking: &ThinkValue{Value: "low"},
+		},
+		{
+			name:             "invalid_string",
+			input:            `{ "think": "invalid" }`,
 			expectedThinking: nil,
 			expectedError:    true,
 		},
@@ -414,7 +426,12 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
 				require.Error(t, err)
 			} else {
 				require.NoError(t, err)
-				assert.Equal(t, test.expectedThinking, req.Think)
+				if test.expectedThinking == nil {
+					assert.Nil(t, req.Think)
+				} else {
+					require.NotNil(t, req.Think)
+					assert.Equal(t, test.expectedThinking.Value, req.Think.Value)
+				}
 			}
 		})
 	}
```
cmd/cmd.go (91 lines changed)
```diff
@@ -322,11 +322,23 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 
 	thinkFlag := cmd.Flags().Lookup("think")
 	if thinkFlag.Changed {
-		think, err := cmd.Flags().GetBool("think")
+		thinkStr, err := cmd.Flags().GetString("think")
 		if err != nil {
 			return err
 		}
-		opts.Think = &think
+
+		// Handle different values for --think
+		switch thinkStr {
+		case "", "true":
+			// --think or --think=true
+			opts.Think = &api.ThinkValue{Value: true}
+		case "false":
+			opts.Think = &api.ThinkValue{Value: false}
+		case "high", "medium", "low":
+			opts.Think = &api.ThinkValue{Value: thinkStr}
+		default:
+			return fmt.Errorf("invalid value for --think: %q (must be true, false, high, medium, or low)", thinkStr)
+		}
 	} else {
 		opts.Think = nil
 	}
```
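The same values flow through the Go client API. A hedged sketch under assumptions (the model name and prompt are placeholders, and a local server is assumed to be running):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "some-thinking-model", // placeholder; any model with the thinking capability
		Prompt: "Why is the sky blue?",
		Think:  &api.ThinkValue{Value: "low"}, // true/false also accepted
	}

	// Thinking tokens stream in resp.Thinking; the answer streams in resp.Response.
	if err := client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```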
```diff
@@ -977,7 +989,7 @@ type runOptions struct {
 	Options      map[string]any
 	MultiModal   bool
 	KeepAlive    *api.Duration
-	Think        *bool
+	Think        *api.ThinkValue
 	HideThinking bool
 }
 
@@ -1017,10 +1029,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
 	}
 
 	switch ch {
-	case ' ':
+	case ' ', '\t':
 		state.wordBuffer = ""
-	case '\n':
+	case '\n', '\r':
 		state.lineLength = 0
+		state.wordBuffer = ""
 	default:
 		state.wordBuffer += string(ch)
 	}
@@ -1078,6 +1091,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	}()
 
 	var state *displayResponseState = &displayResponseState{}
+	var thinkingContent strings.Builder
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
 	var thinkTagOpened bool = false
@@ -1097,14 +1111,21 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		if !thinkTagOpened {
 			fmt.Print(thinkingOutputOpeningText(false))
 			thinkTagOpened = true
+			thinkTagClosed = false
 		}
+		thinkingContent.WriteString(response.Message.Thinking)
 		displayResponse(response.Message.Thinking, opts.WordWrap, state)
 	}
 
 	content := response.Message.Content
-	if thinkTagOpened && !thinkTagClosed && content != "" {
+	if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
+		if !strings.HasSuffix(thinkingContent.String(), "\n") {
+			fmt.Println()
+		}
 		fmt.Print(thinkingOutputClosingText(false))
+		thinkTagOpened = false
 		thinkTagClosed = true
+		state = &displayResponseState{}
 	}
 	// purposefully not putting thinking blocks in the response, which would
 	// only be needed if we later added tool calling to the cli (they get
@@ -1112,6 +1133,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	// about to finish some tool calls)
 	fullResponse.WriteString(content)
 
+	if response.Message.ToolCalls != nil {
+		toolCalls := response.Message.ToolCalls
+		if len(toolCalls) > 0 {
+			fmt.Print(renderToolCalls(toolCalls, false))
+		}
+	}
+
 	displayResponse(content, opts.WordWrap, state)
 
 	return nil
@@ -1196,6 +1224,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 	}()
 
 	var state *displayResponseState = &displayResponseState{}
+	var thinkingContent strings.Builder
 	var thinkTagOpened bool = false
 	var thinkTagClosed bool = false
 
@@ -1213,17 +1242,31 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		if !thinkTagOpened {
 			fmt.Print(thinkingOutputOpeningText(plainText))
 			thinkTagOpened = true
+			thinkTagClosed = false
 		}
+		thinkingContent.WriteString(response.Thinking)
 		displayResponse(response.Thinking, opts.WordWrap, state)
 	}
 
-	if thinkTagOpened && !thinkTagClosed && content != "" {
+	if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
+		if !strings.HasSuffix(thinkingContent.String(), "\n") {
+			fmt.Println()
+		}
 		fmt.Print(thinkingOutputClosingText(plainText))
+		thinkTagOpened = false
 		thinkTagClosed = true
+		state = &displayResponseState{}
 	}
 
 	displayResponse(content, opts.WordWrap, state)
 
+	if response.ToolCalls != nil {
+		toolCalls := response.ToolCalls
+		if len(toolCalls) > 0 {
+			fmt.Print(renderToolCalls(toolCalls, plainText))
+		}
+	}
+
 	return nil
 }
 
@@ -1463,7 +1506,8 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
-	runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
+	runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models")
+	runCmd.Flags().Lookup("think").NoOptDefVal = "true"
 	runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
 
 	stopCmd := &cobra.Command{
```
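The `NoOptDefVal` assignment is what keeps a bare `--think` working after the flag changed from Bool to String: cobra substitutes the no-option default when the flag appears without a value. A minimal standalone sketch of that behavior (the `demo` command is illustrative, not part of this change):

```go
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

func main() {
	cmd := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, args []string) error {
			v, _ := cmd.Flags().GetString("think")
			fmt.Printf("changed=%v value=%q\n", cmd.Flags().Lookup("think").Changed, v)
			return nil
		},
	}
	cmd.Flags().String("think", "", "true/false or high/medium/low")
	cmd.Flags().Lookup("think").NoOptDefVal = "true"

	// "--think" alone yields "true". Note that with NoOptDefVal set, a level
	// must be attached with '=' ("--think=low"), not passed as a separate arg.
	cmd.SetArgs([]string{"--think"})
	_ = cmd.Execute()
}
```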
```diff
@@ -1613,7 +1657,7 @@ func NewCLI() *cobra.Command {
 // to false).
 //
 // If capabilities are not provided, we fetch them from the server.
-func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
+func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*api.ThinkValue, error) {
 	if explicitlySetByUser {
 		return runOpts.Think, nil
 	}
@@ -1640,9 +1684,34 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
 	}
 
 	if thinkingSupported {
-		thinking := true
-		return &thinking, nil
+		return &api.ThinkValue{Value: true}, nil
 	}
 
 	return nil, nil
 }
+
+func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
+	out := ""
+	formatExplanation := ""
+	formatValues := ""
+	if !plainText {
+		formatExplanation = readline.ColorGrey + readline.ColorBold
+		formatValues = readline.ColorDefault
+		out += formatExplanation
+	}
+	for i, toolCall := range toolCalls {
+		argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
+		if err != nil {
+			return ""
+		}
+		if i > 0 {
+			out += "\n"
+		}
+		// all tool calls are unexpected since we don't currently support registering any in the CLI
+		out += fmt.Sprintf("  Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
+	}
+	if !plainText {
+		out += readline.ColorDefault
+	}
+	return out
+}
```
cmd/interactive.go

```diff
@@ -272,16 +272,29 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			}
 			fmt.Println("Set 'quiet' mode.")
 		case "think":
-			think := true
-			opts.Think = &think
+			thinkValue := api.ThinkValue{Value: true}
+			var maybeLevel string
+			if len(args) > 2 {
+				maybeLevel = args[2]
+			}
+			if maybeLevel != "" {
+				// TODO(drifkin): validate the level, could be model dependent
+				// though... It will also be validated on the server once a call is
+				// made.
+				thinkValue.Value = maybeLevel
+			}
+			opts.Think = &thinkValue
 			thinkExplicitlySet = true
 			if client, err := api.ClientFromEnvironment(); err == nil {
 				ensureThinkingSupport(cmd.Context(), client, opts.Model)
 			}
-			fmt.Println("Set 'think' mode.")
+			if maybeLevel != "" {
+				fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
+			} else {
+				fmt.Println("Set 'think' mode.")
+			}
 		case "nothink":
-			think := false
-			opts.Think = &think
+			opts.Think = &api.ThinkValue{Value: false}
 			thinkExplicitlySet = true
 			if client, err := api.ClientFromEnvironment(); err == nil {
 				ensureThinkingSupport(cmd.Context(), client, opts.Model)
@@ -478,7 +491,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 		assistant, err := chat(cmd, opts)
 		if err != nil {
-			if strings.Contains(err.Error(), "does not support thinking") {
+			if strings.Contains(err.Error(), "does not support thinking") ||
+				strings.Contains(err.Error(), "invalid think value") {
 				fmt.Printf("error: %v\n", err)
 				sb.Reset()
 				continue
```
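In the REPL this surfaces as an optional level argument to the existing command, with the confirmation messages printed by the code above:

```
>>> /set think high
Set 'think' mode to 'high'.
>>> /set think
Set 'think' mode.
```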
convert/convert.go

```diff
@@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &bertModel{}
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
+	case "GptOssForCausalLM":
+		conv = &gptossModel{}
 	default:
 		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
 
```
convert/convert_gptoss.go (new file, 179 lines)
```diff
@@ -0,0 +1,179 @@
+package convert
+
+import (
+	"bytes"
+	"cmp"
+	"encoding/binary"
+	"io"
+	"slices"
+	"strings"
+
+	"github.com/ollama/ollama/fs/ggml"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+)
+
+type gptossModel struct {
+	ModelParameters
+	HiddenLayers         uint32  `json:"num_hidden_layers"`
+	HiddenSize           uint32  `json:"hidden_size"`
+	IntermediateSize     uint32  `json:"intermediate_size"`
+	AttentionHeads       uint32  `json:"num_attention_heads"`
+	KeyValueHeads        uint32  `json:"num_key_value_heads"`
+	HeadDim              uint32  `json:"head_dim"`
+	Experts              uint32  `json:"num_experts"`
+	ExpertsPerToken      uint32  `json:"experts_per_token"`
+	RMSNormEpsilon       float32 `json:"rms_norm_eps"`
+	InitialContextLength uint32  `json:"initial_context_length"`
+	RopeTheta            float32 `json:"rope_theta"`
+	RopeScalingFactor    float32 `json:"rope_scaling_factor"`
+	SlidingWindow        uint32  `json:"sliding_window"`
+}
+
+var _ ModelConverter = (*gptossModel)(nil)
+
+func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
+	kv := m.ModelParameters.KV(t)
+	kv["general.architecture"] = "gptoss"
+	kv["general.file_type"] = uint32(4)
+	kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
+	kv["gptoss.block_count"] = m.HiddenLayers
+	kv["gptoss.embedding_length"] = m.HiddenSize
+	kv["gptoss.feed_forward_length"] = m.IntermediateSize
+	kv["gptoss.expert_count"] = m.Experts
+	kv["gptoss.expert_used_count"] = m.ExpertsPerToken
+	kv["gptoss.attention.head_count"] = m.AttentionHeads
+	kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
+	kv["gptoss.attention.key_length"] = m.HeadDim
+	kv["gptoss.attention.value_length"] = m.HeadDim
+	kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
+	kv["gptoss.attention.sliding_window"] = m.SlidingWindow
+	kv["gptoss.rope.freq_base"] = m.RopeTheta
+	kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
+	kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
+	kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
+	kv["tokenizer.ggml.add_bos_token"] = false
+	kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|>
+	kv["tokenizer.ggml.eos_token_ids"] = []int32{
+		199999, /* <|endoftext|> */
+		200002, /* <|return|> */
+		200012, /* <|call|> */
+	}
+	kv["tokenizer.ggml.add_eos_token"] = false
+	return kv
+}
+
+func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
+	var out []*ggml.Tensor
+	mxfp4s := make(map[string]*mxfp4)
+	for _, t := range ts {
+		if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") {
+			dot := strings.LastIndex(t.Name(), ".")
+			name, suffix := t.Name()[:dot], t.Name()[dot+1:]
+			if _, ok := mxfp4s[name]; !ok {
+				mxfp4s[name] = &mxfp4{}
+			}
+
+			switch suffix {
+			case "blocks":
+				mxfp4s[name].blocks = t
+			case "scales":
+				mxfp4s[name].scales = t
+			}
+		} else {
+			out = append(out, &ggml.Tensor{
+				Name:     t.Name(),
+				Kind:     t.Kind(),
+				Shape:    t.Shape(),
+				WriterTo: t,
+			})
+		}
+	}
+
+	for name, mxfp4 := range mxfp4s {
+		dims := mxfp4.blocks.Shape()
+		out = append(out, &ggml.Tensor{
+			Name:     name,
+			Kind:     uint32(ggml.TensorTypeMXFP4),
+			Shape:    []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
+			WriterTo: mxfp4,
+		})
+	}
+
+	return out
+}
+
+func (m *gptossModel) Replacements() []string {
+	return []string{
+		// noop replacements so other replacements will not be applied
+		".blocks", ".blocks",
+		".scales", ".scales",
+		// real replacements
+		"block", "blk",
+		"attn.norm", "attn_norm",
+		"attn.qkv", "attn_qkv",
+		"attn.sinks", "attn_sinks",
+		"attn.out", "attn_out",
+		"mlp.norm", "ffn_norm",
+		"mlp.gate", "ffn_gate_inp",
+		"mlp.mlp1_", "ffn_gate_up_exps.",
+		"mlp.mlp2_", "ffn_down_exps.",
+		"embedding", "token_embd",
+		"norm", "output_norm",
+		"unembedding", "output",
+		"scale", "weight",
+	}
+}
+
+type mxfp4 struct {
+	blocks, scales Tensor
+}
+
+func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
+	var b bytes.Buffer
+	if _, err := m.blocks.WriteTo(&b); err != nil {
+		return 0, err
+	}
+
+	blocksDims := make([]int, len(m.blocks.Shape()))
+	for i, d := range m.blocks.Shape() {
+		blocksDims[i] = int(d)
+	}
+
+	var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
+
+	var s bytes.Buffer
+	if _, err := m.scales.WriteTo(&s); err != nil {
+		return 0, err
+	}
+
+	scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape()))
+	for i, d := range m.scales.Shape() {
+		scalesDims[i] = int(d)
+	}
+
+	var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes()))
+
+	out, err := tensor.Concat(3, scales, blocks)
+	if err != nil {
+		return 0, err
+	}
+
+	out = tensor.Materialize(out)
+
+	if err := out.Reshape(out.Shape().TotalSize()); err != nil {
+		return 0, err
+	}
+
+	u8s, err := native.VectorU8(out.(*tensor.Dense))
+	if err != nil {
+		return 0, err
+	}
+
+	if err := binary.Write(w, binary.LittleEndian, u8s); err != nil {
+		return 0, err
+	}
+
+	return 0, nil
+}
```
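The mxfp4 writer concatenates one shared scale byte per 32-element block with 16 bytes of packed 4-bit values, which is why the logical last dimension becomes `dims[2] * dims[3] * 2`. As a sanity check on that arithmetic, a small sketch (the tensor dimensions below are illustrative, not taken from a real checkpoint):

```go
package main

import "fmt"

func main() {
	// One MXFP4 block: 32 4-bit values packed into 16 bytes, plus 1 shared
	// scale byte, matching the tensor.Concat(3, scales, blocks) layout above.
	const elemsPerBlock = 32
	const blockBytes = elemsPerBlock / 2 // two 4-bit values per byte
	const scaleBytes = 1

	// e.g. a hypothetical [experts, rows, blocks, 16]uint8 blocks tensor:
	dims := []uint64{32, 2880, 90, 16}
	elements := dims[0] * dims[1] * dims[2] * dims[3] * 2 // dims[3]*2 nibbles per block
	packed := dims[0] * dims[1] * dims[2] * (blockBytes + scaleBytes)

	fmt.Printf("%d logical elements stored in %d bytes (%.2f bits/weight)\n",
		elements, packed, float64(packed*8)/float64(elements))
	// 17 bytes per 32 weights works out to 4.25 bits/weight.
}
```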
```diff
@@ -31,8 +31,10 @@ func (t tensorBase) Shape() []uint64 {
 }
 
 const (
-	tensorKindF32 uint32 = iota
-	tensorKindF16
+	tensorKindFP32 uint32 = iota
+	tensorKindFP16
+	tensorKindMXFP4 = 4
+	tensorKindBF16  = 30
 )
 
 func (t tensorBase) Kind() uint32 {
@@ -43,16 +45,16 @@ func (t tensorBase) Kind() uint32 {
 		t.name == "v.pre_tile_position_embd.weight" ||
 		t.name == "v.post_tile_position_embd.weight" {
 		// these tensors are always F32
-		return 0
+		return tensorKindFP32
 	}
 
 	switch len(t.shape) {
 	case 0:
 		panic("invalid tensor shape")
 	case 1:
-		return tensorKindF32
+		return tensorKindFP32
 	default:
-		return tensorKindF16
+		return tensorKindBF16
 	}
 }
 
```
```diff
@@ -150,6 +150,9 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 		}
 
 		f32s = bfloat16.DecodeFloat32(u8s)
+	case "U8":
+		// U8 tensors do not support repacking or type conversion.
+		return io.CopyN(w, f, st.size)
 	default:
 		return 0, fmt.Errorf("unknown data type: %s", st.dtype)
 	}
@@ -162,15 +165,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	}
 
 	switch st.Kind() {
-	case tensorKindF32:
+	case tensorKindFP32:
 		return 0, binary.Write(w, binary.LittleEndian, f32s)
-	case tensorKindF16:
+	case tensorKindFP16:
 		f16s := make([]uint16, len(f32s))
 		for i := range f32s {
 			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
 		}
 
 		return 0, binary.Write(w, binary.LittleEndian, f16s)
+	case tensorKindBF16:
+		u8s := bfloat16.EncodeFloat32(f32s)
+		return 0, binary.Write(w, binary.LittleEndian, u8s)
+	default:
 		return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
 	}
 
```
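BF16 keeps the top half of each float32 bit pattern, so the encode step is essentially a 16-bit truncation. A self-contained sketch of the conversion the `bfloat16` helper performs (the real package may round rather than truncate):

```go
package main

import (
	"fmt"
	"math"
)

// toBF16 truncates a float32 to bfloat16 by keeping the high 16 bits
// (sign, all 8 exponent bits, and the top 7 mantissa bits).
func toBF16(f float32) uint16 {
	return uint16(math.Float32bits(f) >> 16)
}

// fromBF16 widens a bfloat16 back to float32 by zero-filling the low bits.
func fromBF16(b uint16) float32 {
	return math.Float32frombits(uint32(b) << 16)
}

func main() {
	x := float32(3.14159)
	fmt.Println(fromBF16(toBF16(x))) // ~3.140625: same range as f32, less precision
}
```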
convert/tensor_test.go

```diff
@@ -72,236 +72,787 @@ func mul(shape []uint64) int {
 }
 
 func TestSplitDim(t *testing.T) {
-	r := fakeTensor{
-		name:  "a.b",
-		shape: []uint64{3, 4},
-		data:  []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
-	}
-
-	t.Run("no split", func(t *testing.T) {
-		for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
-			if tt.Name != "x.b" {
-				t.Fatalf("expected name 'x', got '%s'", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{3, 4}) {
-				t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
-				t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
-			}
-		}
-	})
-
-	t.Run("even split", func(t *testing.T) {
-		next, stop := iter.Pull(splitDim(&r, 1,
-			split{Replacer: strings.NewReplacer("a", "x")},
-			split{Replacer: strings.NewReplacer("b", "y")},
-		))
-		defer stop()
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "x.b" {
-				t.Fatal("expected name 'x.b', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{3, 2}) {
-				t.Fatal("expected shape [3, 2], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
-				t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
-			}
-		}
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "a.y" {
-				t.Fatal("expected name 'a.y', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{3, 2}) {
-				t.Fatal("expected shape [3, 2], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
-				t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
-			}
-		}
-	})
-
-	t.Run("uneven split", func(t *testing.T) {
-		next, stop := iter.Pull(splitDim(&r, 0,
-			split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
-			split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
-		))
-		defer stop()
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "x.b" {
-				t.Fatal("expected name 'x.b', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{2, 4}) {
-				t.Fatal("expected shape [2, 4], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
-				t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
-			}
-		}
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "a.y" {
-				t.Fatal("expected name 'a.y', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{1, 4}) {
-				t.Fatal("expected shape [1, 4], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
-				t.Fatal("expected data [8, 9, 10, 11], got", f32s)
-			}
-		}
-	})
-
-	t.Run("split with transpose", func(t *testing.T) {
-		next, stop := iter.Pull(splitDim(&r, 1,
-			split{Replacer: strings.NewReplacer("a", "x")},
-			split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
-				return tensor.Transpose(tt, 1, 0)
-			}},
-		))
-		defer stop()
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "x.b" {
-				t.Fatal("expected name 'x.b', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{3, 2}) {
-				t.Fatal("expected shape [3, 2], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
-				t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
-			}
-		}
-
-		{
-			tt, ok := next()
-			if !ok {
-				t.Fatal("expected at least one split")
-			}
-
-			if tt.Name != "a.y" {
-				t.Fatal("expected name 'a.y', got", tt.Name)
-			}
-
-			if !slices.Equal(tt.Shape, []uint64{3, 2}) {
-				t.Fatal("expected shape [3, 2], got", tt.Shape)
-			}
-
-			var b bytes.Buffer
-			if _, err := tt.WriteTo(&b); err != nil {
-				t.Fatal(err)
-			}
-
-			f32s := make([]float32, mul(tt.Shape))
-			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
-				t.Fatal(err)
-			}
-
-			if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
-				t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
-			}
-		}
-	})
+	t.Run("2d", func(t *testing.T) {
+		r := fakeTensor{
+			name:  "a.b",
+			shape: []uint64{3, 4},
+			data:  []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+		}
+
+		t.Run("no split", func(t *testing.T) {
+			for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
+				if tt.Name != "x.b" {
+					t.Fatalf("expected name 'x', got '%s'", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("even split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 1,
+				split{Replacer: strings.NewReplacer("a", "x")},
+				split{Replacer: strings.NewReplacer("b", "y")},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'a.y', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{2, 3, 6, 7, 10, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("uneven split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 0,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{2, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'a.y', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("three way split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 0,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.z" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("uneven three way split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 1,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{2, 6, 10}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.z" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{3, 7, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("split with transpose", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 1,
+				split{Replacer: strings.NewReplacer("a", "x")},
+				split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
+					return tensor.Transpose(tt, 1, 0)
+				}},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'a.y', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{2, 6, 10, 3, 7, 11}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+	})
+
+	t.Run("3d", func(t *testing.T) {
+		r := fakeTensor{
+			name:  "a.b",
+			shape: []uint64{3, 4, 2},
+			data:  []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+		}
+
+		t.Run("no split", func(t *testing.T) {
+			for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
+				if tt.Name != "x.b" {
+					t.Fatalf("expected name 'x', got '%s'", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("even split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 1,
+				split{Replacer: strings.NewReplacer("a", "x")},
+				split{Replacer: strings.NewReplacer("b", "y")},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'a.y', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("uneven split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 0,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{2, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'a.y', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("three way split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 0,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.z" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+
+		t.Run("uneven three way split", func(t *testing.T) {
+			next, stop := iter.Pull(splitDim(&r, 1,
+				split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
+				split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
+				split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
+			))
+			defer stop()
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "x.b" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.y" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{4, 5, 12, 13, 20, 21}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+
+			{
+				tt, ok := next()
+				if !ok {
+					t.Fatal("expected at least one split")
+				}
+
+				if tt.Name != "a.z" {
+					t.Fatal("expected name 'x.b', got", tt.Name)
+				}
+
+				if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
+					t.Errorf("unexpected shape (-want +got):\n%s", diff)
+				}
+
+				var b bytes.Buffer
+				if _, err := tt.WriteTo(&b); err != nil {
+					t.Fatal(err)
+				}
+
+				f32s := make([]float32, mul(tt.Shape))
+				if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(f32s, []float32{6, 7, 14, 15, 22, 23}); diff != "" {
+					t.Errorf("unexpected data (-want +got):\n%s", diff)
+				}
+			}
+		})
+	})
 }
```
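The expected values in these tests follow directly from row-major layout: splitting a `[3, 4]` tensor along the last dimension into two `[3, 2]` halves takes the first two and last two entries of each row. A standalone sketch of that indexing, independent of the converter's `splitDim` (the helper name here is illustrative):

```go
package main

import "fmt"

// splitLast splits a row-major flat tensor of shape [rows, cols] into equal
// chunks along the last dimension, mirroring the "even split" expectations.
func splitLast(data []float32, rows, cols, chunk int) [][]float32 {
	n := cols / chunk
	out := make([][]float32, n)
	for i := range out {
		for r := 0; r < rows; r++ {
			out[i] = append(out[i], data[r*cols+i*chunk:r*cols+i*chunk+chunk]...)
		}
	}
	return out
}

func main() {
	data := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} // shape [3, 4]
	for _, half := range splitLast(data, 3, 4, 2) {
		fmt.Println(half)
	}
	// Output:
	// [0 1 4 5 8 9]
	// [2 3 6 7 10 11]
}
```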
fs/ggml/ggml.go

```diff
@@ -1,6 +1,7 @@
 package ggml
 
 import (
+	"cmp"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -179,6 +180,7 @@ func (kv KV) OllamaEngineRequired() bool {
 		"llama4",
 		"mllama",
 		"qwen25vl",
+		"gptoss",
 	}, kv.Architecture())
 }
 
@@ -280,7 +282,7 @@ func (t Tensor) block() (n int) {
 }
 
 func (t Tensor) blockSize() uint64 {
-	return (TensorType)(t.Kind).BlockSize()
+	return TensorType(t.Kind).BlockSize()
 }
 
 func (t TensorType) BlockSize() uint64 {
@@ -298,6 +300,7 @@ func (t TensorType) BlockSize() uint64 {
 	case
 		2, // Q4_0
 		3, // Q4_1
+		4, // MXFP4
 		6, // Q5_0
 		7, // Q5_1
 		8, // Q8_0
@@ -325,6 +328,8 @@ func (t TensorType) TypeSize() uint64 {
 		return 2 + blockSize/2
 	case TensorTypeQ4_1:
 		return 2 + 2 + blockSize/2
+	case TensorTypeMXFP4:
+		return 1 + blockSize/2
 	case TensorTypeQ5_0:
 		return 2 + 4 + blockSize/2
 	case TensorTypeQ5_1:
```
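With a block size of 32, the new `TensorTypeMXFP4` case works out to 1 scale byte plus 16 packed bytes per block, i.e. 17 bytes per 32 weights. A quick check of that arithmetic:

```go
package main

import "fmt"

func main() {
	const blockSize = 32        // elements per MXFP4 block
	typeSize := 1 + blockSize/2 // 1 scale byte + 4 bits per element = 17
	bitsPerWeight := float64(typeSize) * 8 / blockSize

	fmt.Printf("%d bytes per %d-element block (%.2f bits/weight)\n",
		typeSize, blockSize, bitsPerWeight)
	// 17 bytes per 32-element block (4.25 bits/weight)
}
```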
```diff
@@ -487,9 +492,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	layers := f.Tensors().GroupLayers()
 
 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
+	var kvTotal uint64
 	kv = make([]uint64, f.KV().BlockCount())
 	for i := range kv {
 		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+		kvTotal += kv[i]
 	}
 
 	switch f.KV().Architecture() {
@@ -658,6 +665,18 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				4*qkvBias.Shape[0],
 			)
 		}
+	case "gptoss":
+		kv = make([]uint64, f.KV().BlockCount())
+		for i := range kv {
+			kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+			if i%2 == 0 {
+				kv[i] *= (uint64(numParallel)*4096 + batch)
+			} else {
+				kv[i] *= context
+			}
+		}
+		fullOffload = 4 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
+		partialOffload = 2 * fullOffload
 	}
 
 	return
```
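The gptoss sizing reflects its alternating attention layout: even layers use a 4096-token sliding window, so their cache scales with `numParallel*4096 + batch` rather than the full context, while odd layers attend over the whole context. A small sketch of the per-layer bookkeeping under assumed dimensions (layer count, head counts, and sizes below are illustrative, not the real model config):

```go
package main

import "fmt"

func main() {
	const (
		blocks       = 24  // hypothetical layer count
		headsKV      = 8   // hypothetical KV heads
		headDim      = 64  // hypothetical K and V head size
		bytesPerElem = 2.0 // f16 cache
		context      = 32768
		batch        = 512
		numParallel  = 1
	)

	var total uint64
	for i := 0; i < blocks; i++ {
		perToken := uint64(float64((headDim+headDim)*headsKV) * bytesPerElem)
		if i%2 == 0 {
			total += perToken * (numParallel*4096 + batch) // sliding-window layer
		} else {
			total += perToken * context // full-attention layer
		}
	}
	fmt.Printf("KV cache ≈ %.1f MiB\n", float64(total)/(1<<20))
}
```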
fs/ggml/type.go

```diff
@@ -14,9 +14,9 @@ const (
 	FileTypeF16
 	fileTypeQ4_0
 	fileTypeQ4_1
-	fileTypeQ4_1_F16 // unused by GGML
-	fileTypeQ4_2     // unused by GGML
-	fileTypeQ4_3     // unused by GGML
+	fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML
+	fileTypeQ4_2  // unused by GGML
+	fileTypeQ4_3  // unused by GGML
 	FileTypeQ8_0
 	fileTypeQ5_0
 	fileTypeQ5_1
@@ -97,6 +97,8 @@ func (t FileType) String() string {
 		return "Q4_0"
 	case fileTypeQ4_1:
 		return "Q4_1"
+	case fileTypeMXFP4:
+		return "MXFP4"
 	case FileTypeQ8_0:
 		return "Q8_0"
 	case fileTypeQ5_0:
@@ -144,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType {
 		return TensorTypeQ4_0
 	case fileTypeQ4_1:
 		return TensorTypeQ4_1
+	case fileTypeMXFP4:
+		return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
 	case FileTypeQ8_0:
 		return TensorTypeQ8_0
 	case fileTypeQ5_0:
@@ -187,8 +191,8 @@ const (
 	TensorTypeF16
 	TensorTypeQ4_0
 	TensorTypeQ4_1
-	tensorTypeQ4_2 // unused by GGML
-	tensorTypeQ4_3 // unused by GGML
+	TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
+	tensorTypeQ4_3  // unused by GGML
 	TensorTypeQ5_0
 	TensorTypeQ5_1
 	TensorTypeQ8_0
@@ -260,6 +264,8 @@ func ParseTensorType(s string) (TensorType, error) {
 		return TensorTypeF64, nil
 	case "BF16":
 		return TensorTypeBF16, nil
+	case "MXFP4":
+		return TensorTypeMXFP4, nil
 	default:
 		return 0, fmt.Errorf("unsupported quantization type %s", s)
 	}
@@ -312,6 +318,8 @@ func (t TensorType) String() string {
 		return "F64"
 	case TensorTypeBF16:
 		return "BF16"
+	case TensorTypeMXFP4:
+		return "MXFP4"
 	default:
 		return "unknown"
 	}
```
@@ -19,7 +19,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index a9eeebc6..110c9ece 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -27,7 +27,7 @@ index a9eeebc6..110c9ece 100644
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ @implementation GGMLMetalClass
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);

@@ -12,7 +12,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 110c9ece..ab46f6e3 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -89,7 +89,11 @@
@@ -89,7 +89,11 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];

#if defined(GGML_METAL_USE_BF16)
1293
llama/patches/0023-MXFP4.patch
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,34 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Thu, 31 Jul 2025 12:31:58 -0700
Subject: [PATCH] cuda: disable graph compat check for OP_ADD

---
ggml/src/ggml-cuda/ggml-cuda.cu | 14 --------------
1 file changed, 14 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bb19b06e..080e7467 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2509,20 +2509,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

- // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
- // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
- && node->ne[2] == 1
- && node->ne[3] == 1
- && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
- && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
- use_cuda_graph = false;
-#ifndef NDEBUG
- GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
- }
-
if (node->op == GGML_OP_CPY) {

// Store the pointers which are updated for each token, such that these can be sent
@@ -0,0 +1,25 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 3 Aug 2025 10:00:20 -0700
Subject: [PATCH] Disable ggml-blas on macos v13 and older

---
ggml/src/ggml-blas/ggml-blas.cpp | 5 +++++
1 file changed, 5 insertions(+)

diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index ec158dfa..22926d75 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -505,6 +505,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
};

ggml_backend_reg_t ggml_backend_blas_reg(void) {
+ // MacOS prior to v14 does not include cblas_sgemm - disable this backend if it isn't available
+ if (&cblas_sgemm == NULL) {
+ GGML_LOG_INFO("Disabling ggml-blas backend on old MacOS version\n");
+ return NULL;
+ }
static struct ggml_backend_reg ggml_backend_blas_reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_blas_reg_i,
@@ -276,6 +276,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
QuickGELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
@@ -283,7 +284,7 @@ type Tensor interface {
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Contiguous(ctx Context, shape ...int) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

Pad(ctx Context, shape ...int) Tensor
@@ -468,4 +469,5 @@ const (
DTypeQ80
DTypeQ40
DTypeI32
DTypeMXFP4
)
@@ -239,10 +239,12 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
for _, bt := range bts {
if _, ok := ctxs[bt]; !ok {
// slog.Info("XXX before ggml_init")
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
mem_size: C.ggml_tensor_overhead() * C.size_t(maxTensors),
no_alloc: true,
})
// slog.Info("XXX after ggml_init")
}

targets[t.source.Name] = append(targets[t.source.Name], t.target)
@@ -541,6 +543,8 @@ func (b *Backend) NewContextSize(n int) ml.Context {

var allocatedBuffers []*C.struct_ggml_backend_buffer

// slog.Info("XXX before ggml_init")
// defer slog.Info("XXX after ggml_init")
return &Context{
b: b,
maxGraphNodes: n,
@@ -708,6 +712,8 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
cdtype = C.GGML_TYPE_Q4_0
case ml.DTypeI32:
cdtype = C.GGML_TYPE_I32
case ml.DTypeMXFP4:
cdtype = C.GGML_TYPE_MXFP4
default:
panic("unsupported dtype")
}
@@ -896,6 +902,8 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeQ40
case C.GGML_TYPE_I32:
return ml.DTypeI32
case C.GGML_TYPE_MXFP4:
return ml.DTypeMXFP4
default:
return ml.DTypeOther
}
@@ -958,10 +966,35 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 0:
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
case 1:
return &Tensor{
b: t.b,
t: C.ggml_cont_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_cont_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_cont_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_cont_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
panic("unsupported number of dimensions")
}
}
@@ -1176,11 +1209,18 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {

func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}
opts := rope.Options{
Factors: &Tensor{},
OriginalContextLength: 131072,
ExtrapolationFactor: 0.,
AttentionFactor: 1.,
BetaFast: 32.,
BetaSlow: 1.,
}

// Apply any provided options
for _, option := range options {
option(opts)
option(&opts)
}

dequant := t.t
@@ -1200,10 +1240,10 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
C.int(opts.OriginalContextLength),
C.float(ropeBase),
C.float(ropeScale),
C.float(0.0),
C.float(1.0),
C.float(32.0),
C.float(1.0),
C.float(opts.ExtrapolationFactor),
C.float(opts.AttentionFactor),
C.float(opts.BetaFast),
C.float(opts.BetaSlow),
),
}
}
@@ -1222,6 +1262,13 @@ func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
}
}

func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1350,3 +1397,65 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}

func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor {
// Unchecked to handle quantized types
t := c.newTensor(dtype, shape)
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}

return t
}

// TODO - DRY this out with New if possible
func newTestBackend(size int) *Backend {
var cpus []*C.struct_ggml_backend_device
for _, d := range devices() {
switch C.ggml_backend_dev_type(d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
if len(cpus) == 0 {
// only the first cpu device should be used
cpus = append(cpus, d)
break
}
}
}
var schedBackends []*C.struct_ggml_backend
var schedBufts []*C.struct_ggml_backend_buffer_type
b := C.ggml_backend_dev_init(cpus[0], nil)
bt := C.ggml_backend_get_default_buffer_type(b)
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(runtime.NumCPU())))
// C.ggml_backend_cpu_set_n_threads(b, 1) // DEBUGGING
schedBackends = append(schedBackends, b)
schedBufts = append(schedBufts, bt)
return &Backend{
meta: nil,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(max(8192, size)),
false,
false,
),
input: bt,
maxGraphNodes: max(8192, size),
schedBackends: schedBackends,
schedBufts: schedBufts,
}
}

func newTestContext(b *Backend, n int) *Context {
n = max(8192, n)
// slog.Info("XXX before ggml_init")
// defer slog.Info("XXX after ggml_init")
return &Context{
b: b,
maxGraphNodes: n,
ctx: C.ggml_init(C.struct_ggml_init_params{
mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
no_alloc: true,
}),
}
}
2
ml/backend/ggml/ggml/include/ggml.h
vendored
@@ -353,7 +353,7 @@ extern "C" {
GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
// GGML_TYPE_Q4_2 = 4, support has been removed
GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2
// GGML_TYPE_Q4_3 = 5, support has been removed
GGML_TYPE_Q5_0 = 6,
GGML_TYPE_Q5_1 = 7,

@@ -505,6 +505,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
};

ggml_backend_reg_t ggml_backend_blas_reg(void) {
// MacOS prior to v14 does not include cblas_sgemm - disable this backend if it isn't available
if (&cblas_sgemm == NULL) {
GGML_LOG_INFO("Disabling ggml-blas backend on old MacOS version\n");
return NULL;
}
static struct ggml_backend_reg ggml_backend_blas_reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_blas_reg_i,
7
ml/backend/ggml/ggml/src/ggml-common.h
vendored
@@ -417,6 +417,13 @@ typedef struct {
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");

#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
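
Note: the E8M0 scale byte above encodes a raw power of two; shifting it into a float32 exponent field (bits 23..30) yields 2^(d-127), which is exactly how the kernels below rebuild it. A sketch of full-block decoding under that layout (hypothetical helper; low nibble decodes first, and the 16-entry table matches MXFP4_VALS used in vec.cpp below):

#include <stdint.h>

static const float mxfp4_vals[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Decode one block_mxfp4 (defined above) into 32 floats.
static void mxfp4_decode_block(const block_mxfp4 *b, float out[32]) {
    union { uint32_t as_bits; float as_value; } scale;
    scale.as_bits = ((uint32_t) b->d) << 23;  // E8M0 byte -> float32 exponent field
    for (int i = 0; i < 16; i++) {
        out[2*i]   = scale.as_value * mxfp4_vals[b->qs[i] & 0xf];
        out[2*i+1] = scale.as_value * mxfp4_vals[(b->qs[i] >> 4) & 0xf];
    }
}
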
@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);

#ifdef __cplusplus
}
#endif
5
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
vendored
@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_MXFP4] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
1
ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
vendored
@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_MXFP4:
case GGML_TYPE_COUNT:
{
GGML_ABORT("fatal error");
90
ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
vendored
@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
}
return sum = (ggml_float)logf(sum);
}

#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}

void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
assert(nrc == 1);
GGML_UNUSED(nrc);
GGML_UNUSED(bx);
GGML_UNUSED(by);
GGML_UNUSED(bs);
ggml_float mxfp4_table[] = MXFP4_VALS;

#if defined(GGML_SIMD)
float sumf = 0.0f;
const int np = (n & ~(GGML_F32_STEP - 1));
const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

GGML_F32_VEC scalev;
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64
for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4
// convert GGML_F32_ARR X elements
const int ib = (i + j*GGML_F32_EPR) / MXFP4;
const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
scalev = GGML_F32_VEC_SET1(scale.as_value);
float xf[GGML_F32_EPR]= {0.f};
assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun");
for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) {
xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)];
xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4];
}

ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev);
ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
}
}
GGML_F32_VEC_REDUCE(sumf, sum);

// leftovers
for (int i = np; i < n; i+=2) {
const int ib = i / MXFP4;
const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)];
sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4];
}

#else // defined(GGML_SIMD)
const int nb = n / MXFP4;
assert(n % MXFP4 == 0);

int yi = 0;

const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;

ggml_float sumf = 0.0;
for (int ib = 0; ib < nb; ++ib) {
const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0];
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
for (int i = 0; i < MXFP4/2; ++i) {
sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]);
sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]);
}
}
#endif

*s = sumf;
}
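
Note: a plain scalar cross-check for the vector path above (hypothetical test helper; it reuses mxfp4_decode_block from the earlier sketch and assumes n is a multiple of 32):

static float mxfp4_dot_ref(int n, const block_mxfp4 *x, const float *y) {
    float tmp[32];
    float sumf = 0.0f;
    for (int ib = 0; ib < n/32; ib++) {
        mxfp4_decode_block(&x[ib], tmp);  // scale * E2M1 table lookup per element
        for (int i = 0; i < 32; i++) {
            sumf += tmp[i] * y[ib*32 + i];
        }
    }
    return sumf;
}
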
2
ml/backend/ggml/ggml/src/ggml-cpu/vec.h
vendored
@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);

void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);

void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
80
ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
vendored
@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

// MXFP4 dequantize derived from dequantize_block_q4_0
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
const int64_t i = blockIdx.x;

// assume 32 threads
const int64_t tid = threadIdx.x;
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}

const uint64_t offset = 256*i + MXFP4*ir + 8*il;
dst_t * y = yy + offset;

const block_mxfp4 * x = (const block_mxfp4 *)vx + ib;
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);

// offset within the block 1/4 chunks (8 items)
const uint8_t * q = x->qs + 4*il;

for (int l = 0; l < 4; ++l) {
uint16_t em0 = q[l] & 0x07;
uint16_t em1 = q[l] & 0x70;
// float16 values
iq1m_scale_t x0;
iq1m_scale_t x1;

x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12);
x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8);

// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
}
if (em1 == 0x10) {
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
}
// x is zero, do nothing

// XXX it looks correct here - but mulmat still gives bad results...
// printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
// i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16));
// printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
// i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16));

y[l*2] = scale.as_value * float(x0.f16);
y[l*2+1] = scale.as_value * float(x1.f16);
}
}
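
Note: the bit manipulation above converts an E2M1 nibble (1 sign, 2 exponent, 1 mantissa bit) straight into fp16 bits instead of going through a lookup table. Isolated as a sketch (hypothetical helper; dst_bias/dst_0p5/dst_m_bits are the same constants as above):

#include <stdint.h>

static uint16_t e2m1_to_half_bits(uint8_t nibble) {
    const uint16_t dst_bias = 15, dst_0p5 = 0x3800, dst_m_bits = 10;
    uint16_t em = nibble & 0x07;  // exponent and mantissa bits
    uint16_t x = (uint16_t)((em << (dst_m_bits - 1)) | ((nibble & 0x08) << 12));  // sign to bit 15
    if ((em & 0x06) != 0) {
        x += (dst_bias - 1) << dst_m_bits;  // normal: rebias exponent (E2M1 bias 1 -> fp16 bias 15)
    } else if (em == 0x01) {
        x = dst_0p5 | (x & 0x8000);         // subnormal: the only such value is +-0.5
    }                                       // em == 0: signed zero is already correct
    return x;
}
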
// derived from dequantize_row_q4_0_cuda
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
default:
return nullptr;
}
@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
default:
return nullptr;
}
30
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -21,6 +21,7 @@
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvmxfp4.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas(

const int cc = ggml_cuda_info().devices[id].cc;

const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4;

if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
&& src0->type != GGML_TYPE_MXFP4;
bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else if (use_mul_mat_vec_mxfp4) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr);
} else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) {
ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst);
return;
}
if (ne2 == 1) {
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
@@ -2498,20 +2509,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

// workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
// number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
&& node->ne[2] == 1
&& node->ne[3] == 1
&& node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
&& node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}

if (node->op == GGML_OP_CPY) {

// Store the pointers which are updated for each token, such that these can be sent
@@ -3056,6 +3053,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_BF16:
case GGML_TYPE_MXFP4:
#ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) {
return false;
333
ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu
vendored
Normal file
@@ -0,0 +1,333 @@
#include "ggml.h"
#include "common.cuh"
#include "mmvmxfp4.cuh"

// MXFP4 implementation derived from mmv.cu float32 code paths
typedef union {
half f16;
uint16_t u16;
} f16_t;

template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
static __global__ void mul_mat_vec_mxfp4(
const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
const int64_t channel_dst = blockIdx.y;
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int64_t sample_dst = blockIdx.z;
const int64_t sample_x = sample_dst / sample_ratio;
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int64_t ncols8 = ncols / 8;

const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;

// x_base is offset by blocks of 32 elements
x_base += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
// y is offset by elements
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
// dst is offset by elements
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;

const float4 * y4 = (const float4 *) y;

extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
float * buf_iw = (float *) data_mmv;

if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
}

float sumf = 0.0f;

// each i8 index processes 8 items at a time
for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
// As i8 indexes past a block, we have to offset further
int offset0 = i8 / (MXFP4/8);
int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
const block_mxfp4 *x = x_base+offset0;

union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
if (isnan(scale.as_value)) {
sumf = scale.as_value;
break;
}
const uint8_t qs[4] = {
(uint8_t)(x->qs[xi]),
(uint8_t)(x->qs[xi+1]),
(uint8_t)(x->qs[xi+2]),
(uint8_t)(x->qs[xi+3])
};

const uint8_t el[8] = {
(uint8_t)(qs[0] & 0xf),
(uint8_t)((qs[0] & 0xf0) >> 4),
(uint8_t)(qs[1] & 0xf),
(uint8_t)((qs[1] & 0xf0) >> 4),
(uint8_t)(qs[2] & 0xf),
(uint8_t)((qs[2] & 0xf0) >> 4),
(uint8_t)(qs[3] & 0xf),
(uint8_t)((qs[3] & 0xf0) >> 4)
};

uint16_t em[8];
#pragma unroll
for (int i = 0; i < 8; i++) { em[i] = (uint16_t)(el[i] & 0x07); }

// float16 values
f16_t x4u[8];
#pragma unroll
for (int i = 0; i < 8; i++) { x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12); }

// Three cases:
// x is normal and non-zero: Correct bias
#pragma unroll
for (int i = 0; i < 8; i++) { if ((em[i] & 0x06) != 0) { x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits); } }

// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
#pragma unroll
for (int i = 0; i < 8; i++) { if (em[i] == 0x01) { x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000); } }
// x is zero, do nothing

const float scalef = scale.as_value;
const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
const float4 tmpy0 = y4[i8*2];
const float4 tmpy1 = y4[i8*2+1];
sumf += tmpx0.x * tmpy0.x * scalef;
sumf += tmpx0.y * tmpy0.y * scalef;
sumf += tmpx0.z * tmpy0.z * scalef;
sumf += tmpx0.w * tmpy0.w * scalef;
sumf += tmpx1.x * tmpy1.x * scalef;
sumf += tmpx1.y * tmpy1.y * scalef;
sumf += tmpx1.z * tmpy1.z * scalef;
sumf += tmpx1.w * tmpy1.w * scalef;
}

sumf = warp_reduce_sum<warp_size>(sumf);

if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= warp_size) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum<warp_size>(sumf);
}

if (tid != 0) {
return;
}

dst[row] = sumf;
}

template <typename type_acc>
static void launch_mul_mat_vec_cuda_mxfp4(
const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
// GGML_ASSERT(stride_row % 2 == 0); // TODO
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;

CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;

int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
block_size_best = block_size;
}
}

const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);

switch (block_size_best) {
case 32: {
mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}

static void mul_mat_vec_cuda_mxfp4(
const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
launch_mul_mat_vec_cuda_mxfp4<float>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}

void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS;

const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);

GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(ne13 == ne3);

// GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);

const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;

const int64_t stride_row = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t stride_channel_x = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t stride_sample_x = src0->nb[3] / ts_src0;
const int64_t stride_sample_y = src1->nb[3] / ts_src1;
const int64_t stride_sample_dst = dst->nb[3] / ts_dst;
const int64_t nsamples_dst = ne3;
const int64_t nsamples_x = ne03;
const int64_t nchannels_x = ne02;
const int64_t nrows = ne01;
const int64_t ncols = ne00;

// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;

GGML_ASSERT(ncols_dst == 1);

const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data;
mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream());
}

void ggml_cuda_op_mul_mat_vec_mxfp4(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {

GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;

GGML_ASSERT(src1_ncols == 1);

const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00 / MXFP4;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_dst = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;

const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i;
mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);

GGML_UNUSED(ctx);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}
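
Note: the launcher above sizes thread blocks by brute force, trying every multiple of the warp size up to the device cap and keeping the one that minimizes per-thread loop iterations over ncols. The same search in isolation (hypothetical standalone sketch):

#include <stdint.h>

static int64_t pick_block_size(int64_t ncols, int64_t warp_size, int64_t max_block_size) {
    int64_t best = warp_size;
    int64_t best_iters = (ncols + 2*warp_size - 1) / (2*warp_size);
    for (int64_t bs = 2*warp_size; bs <= max_block_size; bs += warp_size) {
        const int64_t iters = (ncols + 2*bs - 1) / (2*bs);  // ceil(ncols / (2*bs))
        if (iters < best_iters) {
            best_iters = iters;
            best = bs;
        }
    }
    return best;
}
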
9
ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cuh
vendored
Normal file
@@ -0,0 +1,9 @@
#include "common.cuh"

void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

void ggml_cuda_op_mul_mat_vec_mxfp4(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
@@ -421,6 +421,13 @@ typedef struct {
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

#define MXFP4 32
typedef struct {
uint8_t d; // scale E8M0 float
uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");

#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
@@ -1929,6 +1936,9 @@ GGML_TABLE_END()
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2

#define N_R0_MXFP4 4
#define N_SG_MXFP4 2

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -4380,16 +4390,16 @@ void mul_vec_q_n_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const int nb = args.ne00/QK4_0;
uint3 tgpig, // Threadgroup Position in Grid
ushort tiisg, // Thread Index in SIMD Group
ushort sgitg) { // SIMD Group Index in ThreadGroup
const int nb = args.ne00/QK4_0; // src0->ne[0] / 32

const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * nsg + sgitg) * nr0;
const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -9222,6 +9232,49 @@ kernel void kernel_mul_mm_id(
}
}

template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
float4x4 reg_f;
const ushort dst_bias = 15;
const ushort dst_0p5 = 0x3800;
const ushort dst_m_bits = 10;
const half scale = (half)(as_type<float>(((uint32_t)xb->d) << 23));
// il:0 first 16, il:1 last 16
for (int i = 0; i < 8; i++) {
ushort em0 = xb->qs[il*8 + i] & 0x07;
ushort em1 = xb->qs[il*8 + i] & 0x70;
// float16 values
ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);

// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0 = dst_0p5 | (x0 & 0x8000);
}
if (em1 == 0x10) {
x1 = dst_0p5 | (x1 & 0x8000);
}
// x is zero, do nothing

if (isnan(scale)) {
reg_f[i/2][2*(i%2) + 0] = scale;
reg_f[i/2][2*(i%2) + 1] = scale;
} else {
reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
}
}
reg = (type4x4) reg_f;
}

#define QK_NL 16

//
@@ -9289,6 +9342,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;

template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;

//
// indirect matrix-matrix multiplication
//
@@ -9320,6 +9375,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;

template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;

//
// matrix-vector multiplication
@@ -9436,6 +9493,120 @@ kernel void kernel_mul_mv_id(
sgitg);
}
void mul_mv_mxfp4_f32_impl(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const ushort dst_bias = 15;
|
||||
const ushort dst_0p5 = 0x3800;
|
||||
const ushort dst_m_bits = 10;
|
||||
const int nr0 = N_R0_MXFP4;
|
||||
const int nsg = N_SG_MXFP4;
|
||||
const int nw = N_SIMDWIDTH;
|
||||
const int nb = args.ne00/MXFP4;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * nsg + sgitg) * nr0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
// pointers to src0 rows
|
||||
device const block_mxfp4 * ax[nr0];
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
|
||||
ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
|
||||
}
|
||||
|
||||
float yl[16]; // src1 vector cache
|
||||
float sumf[nr0] = {0.f};
|
||||
|
||||
const short ix = (tiisg/2);
|
||||
const short il = (tiisg%2)*16;
|
||||
|
||||
device const float * yb = y + ix*MXFP4 + il;
|
||||
|
||||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
||||
|
||||
#pragma unroll
|
||||
for (short row = 0; row < nr0; row++) {
|
||||
// Processes 16 items
|
||||
device const block_mxfp4 * qb_curr = ax[row] + ib;
|
||||
float d = as_type<float>(((uint32_t)(ax[row] + ib)->d) << 23);
|
||||
// il = 0 or 16
|
||||
device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
ushort em0 = qs[i] & 0x07;
|
||||
ushort em1 = qs[i] & 0x70;
|
||||
ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
|
||||
ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
|
||||
// Three cases:
|
||||
// x is normal and non-zero: Correct bias
|
||||
if ((em0 & 0x06) != 0) {
|
||||
x0 = x0 + ((dst_bias - 1) << dst_m_bits);
|
||||
}
|
||||
if ((em1 & 0x60) != 0) {
|
||||
x1 = x1 + ((dst_bias - 1) << dst_m_bits);
|
||||
}
|
||||
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||||
if (em0 == 0x01) {
|
||||
x0 = dst_0p5 | (x0 & 0x8000);
|
||||
}
|
||||
if (em1 == 0x10) {
|
||||
x1 = dst_0p5 | (x1 & 0x8000);
|
||||
}
|
||||
// x is zero, do nothing
|
||||
if (!isnan(d)) {
|
||||
sumf[row] += yb[i*2] * as_type<half>(x0) * d
|
||||
+ yb[i*2+1] * as_type<half>(x1) * d;
|
||||
} else {
|
||||
sumf[row] = d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yb += MXFP4 * 16;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_mxfp4_f32")]]
|
||||
kernel void kernel_mul_mv_mxfp4_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||
@@ -9465,6 +9636,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
@@ -65,6 +65,9 @@
|
||||
#define N_R0_IQ4_XS 2
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
#define N_R0_MXFP4 4
|
||||
#define N_SG_MXFP4 2
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
|
||||
25
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
static struct ggml_backend_reg    g_ggml_backend_metal_reg;
static struct ggml_backend_device g_ggml_backend_metal_device;

// information about a Metal device
// note: assumes single GPU device - the default one
// TODO: support multiple GPU devices
@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
@@ -934,7 +939,7 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl
    MTLCompileOptions * options = [MTLCompileOptions new];
    options.preprocessorMacros = prep;

    //[options setFastMathEnabled:false];

    metal_library = [device newLibraryWithSource:src options:options error:&error];
@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,         mul_mv_q5_0_f32,         has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,         mul_mv_q5_1_f32,         has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,         mul_mv_q8_0_f32,         has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,        mul_mv_mxfp4_f32,        has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,     mul_mv_id_iq1_m_f32,     has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,    mul_mv_id_iq4_nl_f32,    has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,    mul_mv_id_iq4_xs_f32,    has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,     mul_mv_id_mxfp4_f32,     has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,          mul_mm_f32_f32,          has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,          mul_mm_f16_f32,          has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,         mul_mm_bf16_f32,         has_simdgroup_mm && use_bfloat);
@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,        mul_mm_iq1_m_f32,        has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,       mul_mm_iq4_nl_f32,       has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,       mul_mm_iq4_xs_f32,       has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,        mul_mm_mxfp4_f32,        has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,      mul_mm_id_map0_f16,      has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,      mul_mm_id_map1_f32,      has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,       mul_mm_id_f32_f16,       has_simdgroup_mm);
@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,     mul_mm_id_iq1_m_f16,     has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,    mul_mm_id_iq4_nl_f16,    has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,    mul_mm_id_iq4_xs_f16,    has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,     mul_mm_id_mxfp4_f16,     has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,           rope_norm_f32,           true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,           rope_norm_f16,           true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,          rope_multi_f32,          true);
@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node(
        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32].pipeline; break;
        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32].pipeline; break;
        case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
        default: GGML_ABORT("MUL MAT-MAT not implemented");
    }

@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node(
                smem = 32*sizeof(float);
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
            } break;
        case GGML_TYPE_MXFP4:
            {
                nsg = N_SG_MXFP4;
                nr0 = N_R0_MXFP4;
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
            } break;
        default:
            {
                GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node(
        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16].pipeline; break;
        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16].pipeline; break;
        case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
        default: GGML_ABORT("MUL_MAT_ID not implemented");
    }

@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node(
                smem = 32*sizeof(float);
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
            } break;
        case GGML_TYPE_MXFP4:
            {
                nsg = N_SG_MXFP4;
                nr0 = N_R0_MXFP4;
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
            } break;
        default:
            {
                GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
173 ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal vendored
@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl(
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const int nb = args.ne00/QK4_0;
        uint3  tgpig,   // Threadgroup Position in Grid
        ushort tiisg,   // Thread Index in SIMD Group
        ushort sgitg) { // SIMD Group Index in ThreadGroup
    const int nb = args.ne00/QK4_0; // src0->ne[0] / 32

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * nsg + sgitg) * nr0;
    const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;
@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id(
    }
}

template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
    float4x4 reg_f;
    const ushort dst_bias   = 15;
    const ushort dst_0p5    = 0x3800;
    const ushort dst_m_bits = 10;
    const half scale = (half)(as_type<float>(((uint32_t)xb->d) << 23));
    // il:0 first 16, il:1 last 16
    for (int i = 0; i < 8; i++) {
        ushort em0 = xb->qs[il*8 + i] & 0x07;
        ushort em1 = xb->qs[il*8 + i] & 0x70;
        // float16 values
        ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
        ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);

        // Three cases:
        // x is normal and non-zero: Correct bias
        if ((em0 & 0x06) != 0) {
            x0 = x0 + ((dst_bias - 1) << dst_m_bits);
        }
        if ((em1 & 0x60) != 0) {
            x1 = x1 + ((dst_bias - 1) << dst_m_bits);
        }
        // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        if (em0 == 0x01) {
            x0 = dst_0p5 | (x0 & 0x8000);
        }
        if (em1 == 0x10) {
            x1 = dst_0p5 | (x1 & 0x8000);
        }
        // x is zero, do nothing

        if (isnan(scale)) {
            reg_f[i/2][2*(i%2) + 0] = scale;
            reg_f[i/2][2*(i%2) + 1] = scale;
        } else {
            reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
            reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
        }
    }
    reg = (type4x4) reg_f;
}
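To make the nibble decoding in dequantize_mxfp4 concrete: each byte of block_mxfp4.qs packs two 4-bit E2M1 codes, and all 32 values in a block share one power-of-two scale whose raw float32 exponent is stored in d. A minimal standalone sketch in Go of the same decode, using a lookup table instead of the float16 bit surgery (decodePair and e2m1 are illustrative names, not part of any ollama package):

package main

import (
	"fmt"
	"math"
)

// e2m1 lists the 16 E2M1 values indexed by their 4-bit code
// (sign | 2 exponent bits | 1 mantissa bit) — the same values the
// Metal kernel reconstructs arithmetically.
var e2m1 = [16]float32{
	0, 0.5, 1, 1.5, 2, 3, 4, 6,
	0, -0.5, -1, -1.5, -2, -3, -4, -6,
}

// decodePair expands one packed byte into its two values, applying the
// block's shared E8M0 scale d (a raw float32 exponent, i.e. 2^(d-127)).
func decodePair(q byte, d uint8) (float32, float32) {
	scale := math.Float32frombits(uint32(d) << 23)
	return e2m1[q&0x0F] * scale, e2m1[q>>4] * scale
}

func main() {
	// d = 127 encodes scale 1.0, so the byte 0x7F unpacks to (-6, 6):
	// low nibble 0xF = -6, high nibble 0x7 = +6.
	lo, hi := decodePair(0x7F, 127)
	fmt.Println(lo, hi) // -6 6
}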
#define QK_NL 16

//
@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]]  kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl,  2,     dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs,  QK_NL, dequantize_iq4_xs>;

template [[host_name("kernel_mul_mm_mxfp4_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4,   2,     dequantize_mxfp4>;

//
// indirect matrix-matrix multiplication
//
@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]]  kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl,  2,     dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs,  QK_NL, dequantize_iq4_xs>;

template [[host_name("kernel_mul_mm_id_mxfp4_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4,   2,     dequantize_mxfp4>;

//
// matrix-vector multiplication
@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id(
        sgitg);
}

// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y
void mul_mv_mxfp4_f32_impl(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const ushort dst_bias   = 15;
    const ushort dst_0p5    = 0x3800;
    const ushort dst_m_bits = 10;
    const int nr0 = N_R0_MXFP4;
    const int nsg = N_SG_MXFP4;
    const int nw  = N_SIMDWIDTH;
    const int nb  = args.ne00/MXFP4;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * nsg + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // pointers to src0 rows
    device const block_mxfp4 * ax[nr0];
    for (int row = 0; row < nr0; ++row) {
        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
    }

    float yl[16]; // src1 vector cache
    float sumf[nr0] = {0.f};

    const short ix = (tiisg/2);
    const short il = (tiisg%2)*16;

    device const float * yb = y + ix*MXFP4 + il;

    // each thread in a SIMD group deals with half a block.
    for (int ib = ix; ib < nb; ib += nw/2) {
#pragma unroll
        for (short row = 0; row < nr0; row++) {
            // Processes 16 items
            device const block_mxfp4 * qb_curr = ax[row] + ib;
            float d = as_type<float>(((uint32_t)(ax[row] + ib)->d) << 23);
            // il = 0 or 16
            device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
            for (int i = 0; i < 8; ++i) {
                ushort em0 = qs[i] & 0x07;
                ushort em1 = qs[i] & 0x70;
                ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
                ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
                // Three cases:
                // x is normal and non-zero: Correct bias
                if ((em0 & 0x06) != 0) {
                    x0 = x0 + ((dst_bias - 1) << dst_m_bits);
                }
                if ((em1 & 0x60) != 0) {
                    x1 = x1 + ((dst_bias - 1) << dst_m_bits);
                }
                // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
                if (em0 == 0x01) {
                    x0 = dst_0p5 | (x0 & 0x8000);
                }
                if (em1 == 0x10) {
                    x1 = dst_0p5 | (x1 & 0x8000);
                }
                // x is zero, do nothing
                if (!isnan(d)) {
                    sumf[row] += yb[i*2]   * as_type<half>(x0) * d
                               + yb[i*2+1] * as_type<half>(x1) * d;
                } else {
                    sumf[row] = d;
                }
            }
        }

        yb += MXFP4 * 16;
    }

    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;

    for (int row = 0; row < nr0; ++row) {
        const float tot = simd_sum(sumf[row]);

        if (tiisg == 0 && first_row + row < args.ne01) {
            dst_f32[first_row + row] = tot;
        }
    }
}
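One detail worth calling out in mul_mv_mxfp4_f32_impl above: ix = tiisg/2 and il = (tiisg%2)*16 split every 32-value block between two adjacent SIMD lanes, and the ib loop then strides by nw/2 = 16 blocks. A standalone sketch (Go, illustrative only; laneWork is not a real function) of which (block, half) pairs each lane visits:

package main

import "fmt"

// laneWork mirrors the indexing in mul_mv_mxfp4_f32_impl for a
// 32-wide SIMD group: lane t starts at block t/2, handles either the
// first or the last 16 values of each block it visits, and strides by
// nw/2 = 16 blocks per loop iteration.
func laneWork(tiisg, nb int) (pairs [][2]int) {
	ix := tiisg / 2        // starting block index
	il := (tiisg % 2) * 16 // offset within the block: 0 or 16
	for ib := ix; ib < nb; ib += 16 {
		pairs = append(pairs, [2]int{ib, il})
	}
	return pairs
}

func main() {
	// With 32 blocks per row: lane 0 takes the first half of blocks 0
	// and 16, lane 1 the second half of the same blocks, lane 2 the
	// first half of blocks 1 and 17, and so on.
	fmt.Println(laneWork(0, 32)) // [[0 0] [16 0]]
	fmt.Println(laneWork(1, 32)) // [[0 16] [16 16]]
	fmt.Println(laneWork(2, 32)) // [[1 0] [17 0]]
}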
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;

template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;

template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;

kernel void kernel_pool_2d_max_f32(
        device  const float * src0,
        device        float * dst,
142 ml/backend/ggml/ggml/src/ggml-quants.c vendored
@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE
    quantize_iq2_s(x, y, 1, k, NULL);
}

// =============================== mxfp4 (de)-quantization

void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
    static const int qk = MXFP4;
    static const uint32_t E8_BIAS = 127;
    static const uint32_t E2_BIAS = 1;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max

        for (int j = 0; j < qk; j++) {
            const float v = x[i*qk + j];
            if (amax < fabsf(v)) {
                amax = fabsf(v);
            }
        }

        const float dequant_scale = amax / 6.0f;
        uint32_t dequant_scale_exponent = 0;
        memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent));

        // Rounding up
        dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000;
        // Rounding down
        // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000;

        float dequant_scale_rounded = 0.0f;
        memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded));
        float quant_scale = 0.0f;
        if (dequant_scale_rounded != 0.0f) {
            quant_scale = 1.0f / dequant_scale_rounded;
        }

        y[i].d = (uint8_t)(dequant_scale_exponent >> 23);

        for (int j = 0; j < qk/2; ++j) {
            const float x0 = x[i*qk + j*2]*quant_scale;
            const float x1 = x[i*qk + j*2+1]*quant_scale;

            uint32_t xi0 = 0;
            uint32_t xi1 = 0;
            memcpy(&xi0, &x0, sizeof(xi0));
            memcpy(&xi1, &x1, sizeof(xi1));

            uint32_t s0 = xi0 & 0x80000000;
            uint32_t s1 = xi1 & 0x80000000;
            uint32_t e0 = (xi0 >> 23) & 0xFF;
            uint32_t e1 = (xi1 >> 23) & 0xFF;
            uint32_t m0 = (xi0 & 0x7FFFFF);
            uint32_t m1 = (xi1 & 0x7FFFFF);

            // 0.25 <= x < 0.75 maps to 0.5, a denormal number
            // Move implicit bit 1 at the beginning to mantissa for denormals
            // adjusted_exponents
            uint32_t ae0 = E8_BIAS - (e0 + 1);
            uint32_t ae1 = E8_BIAS - (e1 + 1);
            if (e0 < E8_BIAS) {
                m0 = (0x400000 | (m0 >> 1)) >> ae0;
            }
            if (e1 < E8_BIAS) {
                m1 = (0x400000 | (m1 >> 1)) >> ae1;
            }

            // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
            e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
            e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);

            // Combine sign, exponent, and mantissa, while saturating
            // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
            uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7);
            uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7);
            uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0);
            uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1);
            y[i].qs[j] = v0;
            y[i].qs[j] |= v1 << 4;
        }
    }
}
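The scale selection above boils down to: take the block's absolute maximum, divide by 6 (the largest E2M1 magnitude), and round up to the next power of two by rounding the float32 exponent field up. A minimal standalone sketch of that exponent rounding (Go; blockScale is an illustrative helper, assuming finite non-negative input):

package main

import (
	"fmt"
	"math"
)

// blockScale picks the E8M0 block scale the way the reference
// quantizer does: amax/6, rounded up to a power of two, returned as
// the raw exponent byte stored in block_mxfp4.d.
func blockScale(amax float32) uint8 {
	bits := math.Float32bits(amax / 6.0)
	// Adding 0x007FFFFF carries into the exponent field unless the
	// mantissa is already zero, i.e. round up to the next power of two.
	bits = (bits + 0x007FFFFF) & 0x7F800000
	return uint8(bits >> 23)
}

func main() {
	d := blockScale(6.0) // amax/6 = 1.0, already a power of two
	fmt.Println(d, math.Float32frombits(uint32(d)<<23)) // 127 1
	d = blockScale(9.0) // amax/6 = 1.5, rounds up to 2.0
	fmt.Println(d, math.Float32frombits(uint32(d)<<23)) // 128 2
}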
void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
    assert(k % MXFP4 == 0);

    const int nb = k / MXFP4;
    const uint16_t dst_bias   = 15;
    const uint16_t dst_0p5    = 0x3800;
    const uint16_t dst_m_bits = 10;

    for (int i = 0; i < nb; i++) {
        union {
            uint32_t as_bits;
            float    as_value;
        } scale;
        scale.as_bits = (((uint32_t)x[i].d) << 23);
        for (int j = 0; j < MXFP4/2; ++j) {
            uint16_t em0 = x[i].qs[j] & 0x07;
            uint16_t em1 = x[i].qs[j] & 0x70;
            // float16 values
            uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12);
            uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8);

            // Three cases:
            // x is normal and non-zero: Correct bias
            if ((em0 & 0x06) != 0) {
                x0 = x0 + ((dst_bias - 1) << dst_m_bits);
            }
            if ((em1 & 0x60) != 0) {
                x1 = x1 + ((dst_bias - 1) << dst_m_bits);
            }
            // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
            if (em0 == 0x01) {
                x0 = dst_0p5 | (x0 & 0x8000);
            }
            if (em1 == 0x10) {
                x1 = dst_0p5 | (x1 & 0x8000);
            }
            // x is zero, do nothing

            if (isnan(scale.as_value)) {
                y[i*MXFP4 + j*2]   = scale.as_value;
                y[i*MXFP4 + j*2+1] = scale.as_value;
            } else {
                y[i*MXFP4 + j*2]   = GGML_FP16_TO_FP32(x0)*scale.as_value;
                y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value;
            }
        }
    }
}

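Taken together, the two reference routines implement a simple round-trip: pick a block scale, encode each value as the nearest representable E2M1 code, and decode by table lookup times scale. A compact standalone sketch of that round-trip for a short block (Go; nearestCode uses a plain nearest-value scan, which ignores the reference quantizer's exact bit-level tie-breaking, so treat it as an approximation):

package main

import (
	"fmt"
	"math"
)

var e2m1 = [16]float32{0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6}

// nearestCode returns an E2M1 code closest to x.
func nearestCode(x float32) uint8 {
	best, bestErr := uint8(0), float32(math.MaxFloat32)
	for c, v := range e2m1 {
		if err := float32(math.Abs(float64(x - v))); err < bestErr {
			best, bestErr = uint8(c), err
		}
	}
	return best
}

func main() {
	block := []float32{0.4, -1.2, 3.1, 5.9}
	amax := float32(0)
	for _, v := range block {
		if a := float32(math.Abs(float64(v))); a > amax {
			amax = a
		}
	}
	// Round amax/6 up to a power of two, as quantize_row_mxfp4_ref does.
	bits := (math.Float32bits(amax/6) + 0x007FFFFF) & 0x7F800000
	scale := math.Float32frombits(bits)
	for _, v := range block {
		code := nearestCode(v / scale)
		fmt.Printf("%5.2f -> 0x%x -> %5.2f\n", v, code, e2m1[code]*scale)
	}
}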
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}

// =============================== data validation

static bool validate_float(float f, size_t i) {
@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
        {
            VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
        } break;
    case GGML_TYPE_MXFP4:
        // TODO - anything to validate?
        break;
    case GGML_TYPE_I8:
    case GGML_TYPE_I16:
    case GGML_TYPE_I32:
6 ml/backend/ggml/ggml/src/ggml-quants.h vendored
@@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_
GGML_API void quantize_row_iq3_s_ref  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_iq2_s_ref  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);

GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);

// Dequantization
GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, floa
GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_iq3_s  (const block_iq3_s  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
13 ml/backend/ggml/ggml/src/ggml.c vendored
@@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
        .to_float       = (ggml_to_float_t) dequantize_row_q4_1,
        .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
    },
    [4] = { // GGML_TYPE_Q4_2
        .type_name      = "DEPRECATED",
        .blck_size      = 0,
        .type_size      = 0,
        .is_quantized   = false,
    [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2
        .type_name      = "mxfp4",
        .blck_size      = MXFP4,
        .type_size      = sizeof(block_mxfp4),
        .is_quantized   = true,
        .to_float       = (ggml_to_float_t) dequantize_row_mxfp4,
        .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref,
    },
    [5] = { // GGML_TYPE_Q4_3
        .type_name      = "DEPRECATED",
@@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk(
        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_MXFP4:   result = quantize_mxfp4  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_F16:
            {
                size_t elemsize = sizeof(ggml_fp16_t);
60 ml/backend/ggml/ggml_test.go Normal file
@@ -0,0 +1,60 @@
package ggml

import (
	"bytes"
	"log/slog"
	"os"
	"slices"
	"testing"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
)

func TestMain(m *testing.M) {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	os.Exit(m.Run())
}

func setup(tb testing.TB) ml.Backend {
	tb.Helper()

	f, err := os.CreateTemp(tb.TempDir(), "*.bin")
	if err != nil {
		tb.Fatal(err)
	}
	defer f.Close()

	if err := ggml.WriteGGUF(f, ggml.KV{
		"general.architecture": "test",
		"test.block_count":     uint32(1),
	}, []*ggml.Tensor{
		{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 4))},
	}); err != nil {
		tb.Fatal(err)
	}

	b, err := New(f.Name(), ml.BackendParams{NumGPULayers: 1})
	if err != nil {
		tb.Fatal(err)
	}

	return b
}

// initContextOrSkip takes a testing.T and true for GPU.
// If GPUs are not available, the current test is skipped;
// gpu=false will always succeed.
func initContextOrSkip(t *testing.T, b ml.Backend, gpu bool) ml.Context {
	if gpu && len(b.(*Backend).schedBackends) == 1 {
		t.Skip("No GPU detected, skipping GPU test case")
	}
	ctx := b.NewContext()
	t.Cleanup(func() { ctx.Close() })
	if gpu {
		return ctx.Layer(0)
	}
	return ctx.Input()
}
800 ml/backend/ggml/mxfp4_test.go Normal file
@@ -0,0 +1,800 @@
package ggml

import (
	"math"
	"math/rand"
	"os"
	"testing"

	"github.com/ollama/ollama/ml"

	fsggml "github.com/ollama/ollama/fs/ggml"
)

/*
To get GPUs loading in these tests on windows...

$env:OLLAMA_LIBRARY_PATH="$(pwd)\build\lib\ollama"
$env:PATH="$(pwd)\build\lib\ollama;$env:PATH"

go test .\ml\backend\ggml\... -run TestMXFP4
*/

// MXFP4 reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

var (
	// E2M1 values
	mxfp4_vals = []float32{
		0.0,  // 0 00 0 = 0x0
		0.5,  // 0 00 1 = 0x1
		1.0,  // 0 01 0 = 0x2
		1.5,  // 0 01 1 = 0x3
		2.0,  // 0 10 0 = 0x4
		3.0,  // 0 10 1 = 0x5
		4.0,  // 0 11 0 = 0x6
		6.0,  // 0 11 1 = 0x7
		0.0,  // 1 00 0 = 0x8
		-0.5, // 1 00 1 = 0x9
		-1.0, // 1 01 0 = 0xa
		-1.5, // 1 01 1 = 0xb
		-2.0, // 1 10 0 = 0xc
		-3.0, // 1 10 1 = 0xd
		-4.0, // 1 11 0 = 0xe
		-6.0, // 1 11 1 = 0xf
	}
)
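As a cross-check of the table above, the sixteen E2M1 values follow directly from the bit layout s|ee|m with exponent bias 1 and one mantissa bit. A standalone sketch deriving them (Go; e2m1FromBits is a hypothetical helper, not part of the test file):

package main

import "fmt"

// e2m1FromBits reconstructs an E2M1 value from its 4-bit code.
func e2m1FromBits(code uint8) float32 {
	sign := float32(1)
	if code&0x8 != 0 {
		sign = -1
	}
	e := (code >> 1) & 0x3
	m := float32(code & 0x1)
	if e == 0 {
		return sign * m * 0.5 // exponent 0 is subnormal: 0 or ±0.5
	}
	return sign * float32(int(1)<<(e-1)) * (1 + m/2) // 2^(e-1) * (1 + m/2)
}

func main() {
	for code := uint8(0); code < 16; code++ {
		fmt.Printf("0x%x = %v\n", code, e2m1FromBits(code))
	}
}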
func TestMXFP4Ops(t *testing.T) {
	b := setup(t)
	for _, useGPU := range []bool{false, true} {
		useGPU := useGPU
		var label string
		if useGPU {
			label = "gpu"
		} else {
			label = "cpu"
		}
		t.Run(label, func(t *testing.T) {
			t.Run("mulmatid", func(t *testing.T) {
				// Use exact values that are supported without scaling so we can compare against an fp32 tensor
				t.Run("exact", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s00 = 64
					const s01 = 1
					const s02 = 2
					const s10 = s00
					const s11 = 1
					const s12 = 1
					// const s00 = 2880
					// const s01 = 5760
					// const s02 = 32
					// const s10 = s00
					// const s11 = 1
					// const s12 = 64

					data := [s00 * s01 * s02]float32{}
					for i := range data {
						data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
					}
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s00, s01, s02)
					t1f := ctx.(*Context).FromFloatSlice(data[:], s00, s01, s02)
					// for i := range len(data) / 32 { // MXFP4 block size
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", data[i*32+j])
					// 	}
					// 	t.Logf(" t1[%s]\n", strings.Join(vals[:], ", "))
					// }

					// random 0-1 float
					d2 := [s10 * s11 * s12]float32{}
					for i := range d2 {
						d2[i] = float32(r.Float32())
					}
					// for i := range len(d2) / s10 {
					// 	vals := [s10]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", d2[i*s10+j])
					// 	}
					// 	t.Logf(" t2[%s]\n", strings.Join(vals[:], ", "))
					// }
					t2 := ctx.(*Context).FromFloatSlice(d2[:], s10, s11, s12)

					d3 := [4 * s12]int32{}
					for i := range d3 {
						d3[i] = int32(i) % s02
					}
					t3 := ctx.(*Context).FromIntSlice(d3[:], 4, s12)

					// t.Log("calling MulmatID")
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)
					d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(2)) // lower precision for CPU accuracy
					d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(2))
					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("MulmatID results matched:\n%s", d4)
				})

				t.Run("range", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s0 = 64
					const s1 = 2
					const s2 = 4
					const idlen = 4
					data := [s0 * s1 * s2]float32{}
					inTotal := float32(0)
					for i := range data {
						data[i] = float32(i)
						inTotal += float32(i)
					}
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// Reconvert back to floats to remove the quantization fidelity loss for comparison
					dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1, s2)
					t1f := ctx.(*Context).FromFloatSlice(dataf, s0, s1, s2)
					// for i := range len(data) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", dataf[i*32+j])
					// 	}
					// 	t.Logf(" t1[%s]\n", strings.Join(vals[:], ", "))
					// }

					d2 := [s0]float32{}
					for i := range d2 {
						// d2[i] = float32(i)
						d2[i] = float32(r.Float32())
					}
					// for i := range len(d2) / s0 {
					// 	vals := [s0]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j])
					// 	}
					// 	t.Logf(" t2[%s]\n", strings.Join(vals[:], ", "))
					// }
					t2 := ctx.(*Context).FromFloatSlice(d2[:], s0)

					// TODO - there might be a CUDA bug here...
					d3 := [idlen]int32{1, 1, 2, 3}
					// for i := range d3 {
					// 	d3[i] = int32(i) % s2
					// 	t.Logf("%d] %d", i, d3[i])
					// }
					t3 := ctx.(*Context).FromIntSlice(d3[:], idlen)

					// t.Log("calling Mulmat")
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)
					// Metal has some drift so use reduced precision for dump comparisons
					d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(2))
					d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(2))
					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("mxfp4 result\n%s", d4)
				})
				t.Run("random", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s00 = 2880
					const s01 = 5760
					const s02 = 32
					const s10 = s00
					const s11 = 1
					const s12 = 64
					const idlen = 4

					data := [s00 * s01 * s02]float32{}
					for i := range data {
						data[i] = float32(r.Float32() * 10.0)
					}
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// Reconvert back to floats to remove the quantization fidelity loss for comparison
					dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s00, s01, s02)
					t1f := ctx.(*Context).FromFloatSlice(dataf, s00, s01, s02)
					// for i := range len(data) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", dataf[i*32+j])
					// 	}
					// 	t.Logf(" t1[%s]\n", strings.Join(vals[:], ", "))
					// }

					d2 := [s10 * s11 * s12]float32{}
					for i := range d2 {
						// d2[i] = float32(i)
						d2[i] = float32(r.Float32())
					}
					// for i := range len(d2) / s0 {
					// 	vals := [s0]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j])
					// 	}
					// 	t.Logf(" t2[%s]\n", strings.Join(vals[:], ", "))
					// }
					t2 := ctx.(*Context).FromFloatSlice(d2[:], s10, s11, s12)

					// arange equiv
					d3 := [idlen * s12]int32{}
					for i := range d3 {
						d3[i] = int32(i) % s02
					}
					t3 := ctx.(*Context).FromIntSlice(d3[:], idlen, s12)

					// t.Log("calling Mulmat")
					// t3 := t1.Mulmat(ctx, t2)
					// t3f := t1f.Mulmat(ctx, t2)
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)
					// Metal and CPU have some drift so use reduced precision for dump comparisons
					d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(1))
					d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(1))
					// t.Logf("mxfp4 data: \n%s", d4)
					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
				})

				// Use data file(s) with real data
				t.Run("example_7", func(t *testing.T) {
					ctx := initContextOrSkip(t, b, useGPU)
					data0, err := os.ReadFile("mlp-gateup.bin")
					if err != nil {
						t.Skip("missing mlp-gateup.bin file, skipping test")
					}
					data1, err := os.ReadFile("hidden-states-7.bin")
					if err != nil {
						t.Skip("missing hidden-states.bin file, skipping test")
					}
					data2, err := os.ReadFile("selected-experts-7.bin")
					if err != nil {
						t.Skip("missing selected-experts.bin file, skipping test")
					}

					dtype := ml.DTypeMXFP4
					data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32)
					t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32)
					t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32)

					// t.Logf("f32: \n%s", ml.Dump(ctx, t1f))

					t2 := ctx.(*Context).FromBytes(ml.DTypeF32, data1, 2880, 1, 7)
					// t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2))

					t3 := ctx.(*Context).FromBytes(ml.DTypeI32, data2, 4, 7)
					// t.Logf("experts: \n%s", ml.Dump(ctx, t3))

					// t.Log("calling MulmatID")
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)

					d4 := ml.Dump(ctx, t4)
					d4f := ml.Dump(ctx, t4f)

					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("MulmatID results matched:\n%s", d4)
				})

				// Use data file(s) with real data
				t.Run("example_384", func(t *testing.T) {
					ctx := initContextOrSkip(t, b, useGPU)
					data0, err := os.ReadFile("mlp-gateup.bin")
					if err != nil {
						t.Skip("missing mlp-gateup.bin file, skipping test")
					}
					data1, err := os.ReadFile("hidden-states-384.bin")
					if err != nil {
						t.Skip("missing hidden-states.bin file, skipping test")
					}
					data2, err := os.ReadFile("selected-experts-384.bin")
					if err != nil {
						t.Skip("missing selected-experts.bin file, skipping test")
					}

					dtype := ml.DTypeMXFP4
					data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32)
					t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32)
					t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32)

					// t.Logf("f32: \n%s", ml.Dump(ctx, t1f))

					t2 := ctx.(*Context).FromBytes(ml.DTypeF32, data1, 2880, 1, 384)
					// t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2))

					t3 := ctx.(*Context).FromBytes(ml.DTypeI32, data2, 4, 384)
					// t.Logf("experts: \n%s", ml.Dump(ctx, t3))

					// t.Log("calling MulmatID")
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)

					d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(3))
					d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(3))

					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("MulmatID results matched:\n%s", d4)
				})

				// Use data file(s) with real data
				t.Run("example_1d", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					data0, err := os.ReadFile("mlp-gateup.bin")
					if err != nil {
						t.Skip("missing mlp-gateup.bin file, skipping test")
					}

					dtype := ml.DTypeMXFP4
					data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32)
					t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32)
					t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32)

					// t.Logf("f32: \n%s", ml.Dump(ctx, t1f))
					data1 := [2880]float32{}
					for i := range data1 {
						data1[i] = float32(r.Float32())
					}

					t2 := ctx.(*Context).FromFloatSlice(data1[:], 2880)
					// t.Logf("hidden-state: \n%s", ml.Dump(ctx, t2))
					data2 := [4]int32{
						12, 30, 17, 7,
						// 7, 17, 12, 30,
					}

					t3 := ctx.(*Context).FromIntSlice(data2[:], 4)
					// t.Logf("experts: \n%s", ml.Dump(ctx, t3))

					// t.Log("calling MulmatID")
					t4 := t1.MulmatID(ctx, t2, t3)
					t4f := t1f.MulmatID(ctx, t2, t3)

					d4 := ml.Dump(ctx, t4)
					d4f := ml.Dump(ctx, t4f)

					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("MulmatID results matched:\n%s", d4)
				})

			})

			t.Run("mm", func(t *testing.T) {

				t.Run("example", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					data0, err := os.ReadFile("mlp-gateup.bin")
					if err != nil {
						t.Skip("missing mlp-gateup.bin file, skipping test")
					}
					data1 := [2880 * 1 * 32]float32{}
					for i := range data1 {
						data1[i] = float32(r.Float32())
					}

					dtype := ml.DTypeMXFP4
					data0f := ConvertToF32(data0, uint32(fsggml.TensorTypeMXFP4), 2880*5760*32)
					t1 := ctx.(*Context).FromBytes(dtype, data0, 2880, 5760, 32)
					t1f := ctx.(*Context).FromFloatSlice(data0f, 2880, 5760, 32)

					// t.Logf("f32: \n%s", ml.Dump(ctx, t1f))

					t2 := ctx.(*Context).FromFloatSlice(data1[:], 2880, 1, 32)

					t4 := t1.Mulmat(ctx, t2)
					t4f := t1f.Mulmat(ctx, t2)

					d4 := ml.Dump(ctx, t4, ml.DumpWithPrecision(3))
					d4f := ml.Dump(ctx, t4f, ml.DumpWithPrecision(3))

					r4 := t4.Floats()
					r4f := t4f.Floats()
					sim := cosineSimilarity(r4, r4f)
					if sim < 0.99 {
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)

					if d4 != d4f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d4f, d4)
					}
					// t.Logf("Mulmat results matched:\n%s", d4)
				})

				t.Run("exact/3x3", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s10 = 64
					const s11 = 1
					const s12 = 2
					const s20 = s10
					const s21 = 1
					const s22 = 2

					data := [s10 * s11 * s12]float32{}
					for i := range data {
						data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
					}
					// for i := range len(data) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", data[i*32+j])
					// 	}
					// 	t.Logf(" [%s]\n", strings.Join(vals[:], ", "))
					// }
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// for i := range len(mxData) / 17 {
					// 	vals := [17]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j])
					// 	}
					// 	t.Logf(" %s\n", strings.Join(vals[:], ", "))
					// }
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s10, s11, s12)
					t1f := ctx.(*Context).FromFloatSlice(data[:], s10, s11, s12)

					d2 := [s20 * s21 * s22]float32{}
					for i := range d2 {
						d2[i] = float32(r.Float32())
					}
					t2 := ctx.(*Context).FromFloatSlice(d2[:], s20, s21, s22)

					t3f := t1f.Mulmat(ctx, t2)
					t3 := t1.Mulmat(ctx, t2)
					d3 := ml.Dump(ctx, t3)
					d3f := ml.Dump(ctx, t3f)
					if d3 != d3f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
					}
				})

				t.Run("exact/2x2", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s0 = 32
					const s1 = 64

					data := [s0 * s1]float32{}
					for i := range data {
						data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
					}
					// for i := range 4 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", data[i*32+j])
					// 	}
					// 	t.Logf(" [%s]\n", strings.Join(vals[:], ", "))
					// }
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// for i := range len(mxData) / 17 {
					// 	vals := [17]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j])
					// 	}
					// 	t.Logf(" %s\n", strings.Join(vals[:], ", "))
					// }
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1)
					t1f := ctx.(*Context).FromFloatSlice(data[:], s0, s1)

					d2 := [s0 * s1]float32{}
					for i := range d2 {
						d2[i] = float32(r.Float32())
					}
					t2 := ctx.(*Context).FromFloatSlice(d2[:], s0, s1)

					t3f := t1f.Mulmat(ctx, t2)
					t3 := t1.Mulmat(ctx, t2)
					d3 := ml.Dump(ctx, t3)
					d3f := ml.Dump(ctx, t3f)
					if d3 != d3f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
					}
				})
				t.Run("exact/2x1", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s0 = 64
					const s1 = 4

					data := [s0 * s1]float32{}
					for i := range data {
						data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
					}
					// for i := range len(data) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", data[i*32+j])
					// 	}
					// 	t.Logf(" t1[%s]\n", strings.Join(vals[:], ", "))
					// }
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// for i := range len(mxData) / 17 {
					// 	vals := [17]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2x", mxData[i*17+j])
					// 	}
					// 	t.Logf(" %s\n", strings.Join(vals[:], ", "))
					// }
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1)
					t1f := ctx.(*Context).FromFloatSlice(data[:], s0, s1)

					d2 := [s0]float32{}
					for i := range d2 {
						d2[i] = float32(r.Float32())
					}
					// for i := range len(d2) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", d2[i*32+j])
					// 	}
					// 	t.Logf(" t2[%s]\n", strings.Join(vals[:], ", "))
					// }

					t2 := ctx.(*Context).FromFloatSlice(d2[:], s0)

					t3f := t1f.Mulmat(ctx, t2)
					t3 := t1.Mulmat(ctx, t2)
					d3 := ml.Dump(ctx, t3, ml.DumpWithPrecision(3))
					d3f := ml.Dump(ctx, t3f, ml.DumpWithPrecision(3))
					if d3 != d3f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
					}
				})

				t.Run("range/2d", func(t *testing.T) {
					r := rand.New(rand.NewSource(0))
					ctx := initContextOrSkip(t, b, useGPU)
					const s0 = 32
					const s1 = 4
					data := [s0 * s1]float32{}
					inTotal := float32(0)
					for i := range data {
						data[i] = float32(i)
						inTotal += float32(i)
					}
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					// Reconvert back to floats to remove the quantization fidelity loss for comparison
					dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
					dtype := ml.DTypeMXFP4
					t1 := ctx.(*Context).FromBytes(dtype, mxData, s0, s1)
					t1f := ctx.(*Context).FromFloatSlice(dataf, s0, s1)
					// for i := range len(data) / 32 {
					// 	vals := [32]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", dataf[i*32+j])
					// 	}
					// 	t.Logf(" t1[%s]\n", strings.Join(vals[:], ", "))
					// }

					d2 := [s0 * s1]float32{}
					for i := range d2 {
						// d2[i] = float32(i)
						d2[i] = float32(r.Float32())
					}
					// for i := range len(d2) / s0 {
					// 	vals := [s0]string{}
					// 	for j := range vals {
					// 		vals[j] = fmt.Sprintf("%0.2f", d2[i*s0+j])
					// 	}
					// 	t.Logf(" t2[%s]\n", strings.Join(vals[:], ", "))
					// }

					t2 := ctx.(*Context).FromFloatSlice(d2[:], s0, s1)

					// t.Log("calling Mulmat")
					t3 := t1.Mulmat(ctx, t2)
					t3f := t1f.Mulmat(ctx, t2)
					d3 := ml.Dump(ctx, t3, ml.DumpWithPrecision(2))
					d3f := ml.Dump(ctx, t3f, ml.DumpWithPrecision(2))
					r3 := t3.Floats()
					r3f := t3f.Floats()
					sim := cosineSimilarity(r3, r3f)
					if sim < 0.99 {
						t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)
					if d3 != d3f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
					}
				})

				t.Run("range/3d", func(t *testing.T) {
					ctx := initContextOrSkip(t, b, useGPU)
					data := [32 * 4 * 2]float32{}
					inTotal := float32(0)
					for i := range data {
						data[i] = float32(i)
						inTotal += float32(i)
					}
					mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
					dtype := ml.DTypeMXFP4
					// Reconvert back to floats to remove the quantization fidelity loss for comparison
					dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
					t1 := ctx.(*Context).FromBytes(dtype, mxData, 32, 4, 2)
					t1f := ctx.(*Context).FromFloatSlice(dataf, 32, 4, 2)

					d2 := [32 * 4 * 2]float32{}
					for i := range d2 {
						d2[i] = 2.0
					}
					t2 := ctx.(*Context).FromFloatSlice(d2[:], 32, 4, 2)

					// t.Log("calling Mulmat")
					t3 := t1.Mulmat(ctx, t2)
					t3f := t1f.Mulmat(ctx, t2)
					d3 := ml.Dump(ctx, t3)
					d3f := ml.Dump(ctx, t3f)
					r3 := t3.Floats()
					r3f := t3f.Floats()
					sim := cosineSimilarity(r3, r3f)
					if sim < 0.99 {
						t.Logf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
						t.Fatalf("failed similarity test: %f", sim)
					}
					t.Logf("similarity: %f", sim)
					if d3 != d3f {
						t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
					}
				})
			})
		})
	}
}

func TestMXFP4Simple(t *testing.T) {
	b := setup(t)

	t.Run("fixed", func(t *testing.T) {
		ctx := initContextOrSkip(t, b, false)
		data := [32 * 2]float32{
			2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
			2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
		}
		mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
		dtype := ml.DTypeMXFP4
		// Reconvert back to floats to remove the quantization fidelity loss for comparison
		dataf := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
		t1 := ctx.(*Context).FromBytes(dtype, mxData, 32, 2)
		t1f := ctx.(*Context).FromFloatSlice(dataf, 32, 2)

		d2 := [32 * 2]float32{
			// 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
			1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
		}
		t2 := ctx.(*Context).FromFloatSlice(d2[:], 32, 2)

		t.Log("calling Mulmat")
		t3f := t1f.Mulmat(ctx, t2)
		t3 := t1.Mulmat(ctx, t2)
		d3 := ml.Dump(ctx, t3)
		d3f := ml.Dump(ctx, t3f)
		if d3 != d3f {
			t.Fatalf("expected (f32): \n%s\n\n but got (mxfp4): \n%s", d3f, d3)
		}
		t.Logf("result (mxfp4): \n%s", d3)
	})

}

func TestMXFP4Conversion(t *testing.T) {
	t.Run("quantize/exact", func(t *testing.T) {
		r := rand.New(rand.NewSource(0))

		data := [32 * 4]float32{}
		for i := range data {
			data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] * 0.1
		}
		mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
		newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))

		if len(data) != len(newData) {
			t.Fatalf("length mismatch. started with %d but got %d", len(data), len(newData))
		}
		for i := range data {
			if data[i] != newData[i] {
				t.Logf("started with: %v", data)
				t.Logf("got         : %v", newData)
				t.Fatalf("mismatched data starting at offset %d started with %f but got %f", i, data[i], newData[i])
			}
		}
	})
	t.Run("quantize/arange", func(t *testing.T) {
		data := [32 * 8]float32{}
		for i := range data {
			data[i] = float32(i) // / float32(6.0)
		}
		mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
		newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))

		if len(data) != len(newData) {
			t.Fatalf("length mismatch. started with %d but got %d", len(data), len(newData))
		}
		sim := cosineSimilarity(data[:], newData)
		if sim < 0.99 {
			t.Fatalf("failed similarity test: %f", sim)
		}
		t.Logf("similarity: %f", sim)
	})
}

func dotProduct[V float32 | float64](v1, v2 []V) V {
	var result V = 0
	for i := range v1 {
		result += v1[i] * v2[i]
	}
	return result
}

func magnitude[V float32 | float64](v []V) V {
	var result V = 0
	for _, val := range v {
		result += val * val
	}
	return V(math.Sqrt(float64(result)))
}

func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
	return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
}
@@ -44,6 +44,8 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
		C.dequantize_row_q6_K((*C.block_q6_K)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems)
	case C.GGML_TYPE_BF16:
		C.ggml_bf16_to_fp32_row((*C.ggml_bf16_t)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems)
	case C.GGML_TYPE_MXFP4:
		C.dequantize_row_mxfp4((*C.block_mxfp4)(unsafe.Pointer(&data[0])), (*C.float)(&f32s[0]), elems)
	default:
		panic("unsupported quantization format")
	}
@@ -15,3 +15,26 @@ func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {

	return t
}

type LinearBatch struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
	t = m.Weight.MulmatID(ctx, t, indices)
	if m.Bias != nil {
		var bias ml.Tensor
		if len(indices.Shape()) > 1 {
			// FIXME: Rows does not support 2D indices for a 2D input tensor so reshape indices to 1D.
			bias = m.Bias.Rows(ctx, indices.Contiguous(ctx, indices.Dim(0)*indices.Dim(1))).
				Duplicate(ctx).
				Reshape(ctx, m.Bias.Dim(0), indices.Dim(0), indices.Dim(1))
		} else {
			bias = m.Bias.Rows(ctx, indices)
		}
		t = t.Add(ctx, bias)
	}

	return t
}
@@ -4,9 +4,15 @@ import "github.com/ollama/ollama/ml"

// Options contains optional parameters for RoPE function
type Options struct {
	OriginalContextLength int
	Type    int
	Factors ml.Tensor
	OriginalContextLength int

	// YaRN options
	ExtrapolationFactor,
	AttentionFactor,
	BetaFast,
	BetaSlow float32
}

// WithOriginalContextLength sets a custom context length
@@ -31,3 +37,15 @@ func WithFactors(factors ml.Tensor) func(*Options) {
		}
	}
}

func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
	return func(opts *Options) {
		opts.ExtrapolationFactor = extrapolationFactor
	}
}

func WithAttentionFactor(attentionFactor float32) func(*Options) {
	return func(opts *Options) {
		opts.AttentionFactor = attentionFactor
	}
}
@@ -22,7 +22,7 @@ var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||
return BytePairEncoding{
|
||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||
pre: regexp2.MustCompile(pre, regexp2.None),
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
268
model/models/gptoss/model.go
Normal file
@@ -0,0 +1,268 @@
package gptoss

import (
    "cmp"
    "math"
    "strings"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/ml/nn/fast"
    "github.com/ollama/ollama/ml/nn/rope"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type Transformer struct {
    model.Base
    model.BytePairEncoding

    TokenEmbedding    *nn.Embedding      `gguf:"token_embd"`
    TransformerBlocks []TransformerBlock `gguf:"blk"`
    OutputNorm        *nn.RMSNorm        `gguf:"output_norm"`
    Output            *nn.Linear         `gguf:"output,alt:token_embd"`

    Options
}

// Forward implements model.Model.
func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
    positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))

    one := ctx.Input().FromFloatSlice([]float32{1}, 1)
    for i, block := range m.TransformerBlocks {
        m.Cache.SetLayer(i)
        if c, ok := m.Cache.(*kvcache.WrapperCache); ok {
            // Even layers are sliding window attention.
            c.SetLayerType(i % 2)
        }

        var outputs ml.Tensor
        if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
            outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
        }

        hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
    }

    hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
    return m.Output.Forward(ctx, hiddenStates), nil
}

func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}

type Options struct {
    hiddenSize,
    numHeads,
    numKVHeads,
    keyLength,
    valueLength,
    numExperts,
    numExpertsUsed,
    originalContextLength int

    eps,
    ropeBase,
    ropeScale float32
}

func (o Options) RoPEOptions() []func(*rope.Options) {
    return []func(*rope.Options){
        rope.WithTypeNeoX(),
        rope.WithOriginalContextLength(o.originalContextLength),
        rope.WithExtrapolationFactor(1.),
        // NOTE: ggml sets this implicitly so there's no need to set it here
        // rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
    }
}

func (o Options) headDim() int {
    return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}

type TransformerBlock struct {
    Attention *AttentionBlock
    MLP       *MLPBlock
}

func (d *TransformerBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs, one ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
    if outputs != nil {
        hiddenStates = hiddenStates.Rows(ctx, outputs)
    }

    hiddenStates = d.MLP.Forward(ctx, hiddenStates, one, opts)
    return hiddenStates
}

type AttentionBlock struct {
    Norm   *nn.RMSNorm `gguf:"attn_norm"`
    QKV    *nn.Linear  `gguf:"attn_qkv"`
    Output *nn.Linear  `gguf:"attn_out"`
    Sinks  ml.Tensor   `gguf:"attn_sinks"`
}

func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenStates.Dim(1)

    residual := hiddenStates
    hiddenStates = attn.Norm.Forward(ctx, hiddenStates, opts.eps)

    qkv := attn.QKV.Forward(ctx, hiddenStates)

    // query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
    query := qkv.View(ctx,
        0,
        opts.headDim(), qkv.Stride(0)*opts.headDim(),
        opts.numHeads, qkv.Stride(1),
        batchSize,
    )
    query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)

    // key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
    key := qkv.View(ctx,
        qkv.Stride(0)*opts.headDim()*opts.numHeads,
        opts.headDim(), qkv.Stride(0)*opts.headDim(),
        opts.numKVHeads, qkv.Stride(1),
        batchSize,
    )
    key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)

    // value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
    value := qkv.View(ctx,
        qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
        opts.headDim(), qkv.Stride(0)*opts.headDim(),
        opts.numKVHeads, qkv.Stride(1),
        batchSize,
    )

    cache.Put(ctx, key, value)
    key, value, mask := cache.Get(ctx)

    query = query.Permute(ctx, 0, 2, 1, 3)
    key = key.Permute(ctx, 0, 2, 1, 3)

    scores := key.MulmatFullPrec(ctx, query)
    scores = scores.Scale(ctx, 1./math.Sqrt(float64(opts.headDim())))
    scores = scores.Add(ctx, mask)

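    // Attention sinks: append one learned logit per head to each row of scores
    // so softmax can assign probability mass to no token at all; the negative
    // Pad below trims that extra column off again before weighting the values.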
    scores = scores.Concat(ctx, attn.Sinks.Reshape(ctx, 1, 1, opts.numHeads, 1).Repeat(ctx, 1, batchSize), 0)
    scores = scores.Softmax(ctx)
    scores = scores.Pad(ctx, -1, 0, 0, 0)

    attention := value.Mulmat(ctx, scores)
    attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)

    return attn.Output.Forward(ctx, attention).Add(ctx, residual)
}

type MLPBlock struct {
    Norm   *nn.RMSNorm     `gguf:"ffn_norm"`
    Router *nn.Linear      `gguf:"ffn_gate_inp"`
    GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
    Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
}

func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *Options) ml.Tensor {
    hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)

    residual := hiddenStates
    hiddenStates = mlp.Norm.Forward(ctx, hiddenStates, opts.eps)

    hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
    routingWeights := mlp.Router.Forward(ctx, hiddenStates)

    selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
    routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, sequenceLength*batchSize).Rows(ctx, selectedExperts)
    routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, sequenceLength*batchSize).Softmax(ctx)
    routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, sequenceLength*batchSize)

    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

    hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
    hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))

    dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}

    glu := hiddenStates.View(ctx, 0, dimStride...)
    glu = glu.Contiguous(ctx)
    glu = glu.Clamp(ctx, float32(math.Inf(-1)), 7.0)
    glu = glu.QuickGELU(ctx)

    linear := hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
    linear = linear.Clamp(ctx, -7.0, 7.0)

    hiddenStates = glu.Mul(ctx, linear.Add(ctx, one))
    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3))

    experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
    experts = experts.Mul(ctx, routingWeights)

    nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
    for i := 1; i < opts.numExpertsUsed; i++ {
        nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
    }

    return nextStates.Add(ctx, residual)
}

func New(c fs.Config) (model.Model, error) {
    m := Transformer{
        TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
        BytePairEncoding: model.NewBytePairEncoding(
            c.String("tokenizer.ggml.pretokenizer",
                strings.Join([]string{
                    `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
                    `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
                    `\p{N}{1,3}`,
                    ` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
                    `\s*[\r\n]+`,
                    `\s+(?!\S)`,
                    `\s+`,
                }, "|"),
            ),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Types:  c.Ints("tokenizer.ggml.token_type"),
                Merges: c.Strings("tokenizer.ggml.merges"),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOS: append(
                    []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
                    c.Ints("tokenizer.ggml.eos_token_ids")...,
                ),
            },
        ),
        Options: Options{
            hiddenSize:            int(c.Uint("embedding_length")),
            numHeads:              int(c.Uint("attention.head_count")),
            numKVHeads:            int(c.Uint("attention.head_count_kv")),
            keyLength:             int(c.Uint("attention.key_length")),
            valueLength:           int(c.Uint("attention.value_length")),
            numExperts:            int(c.Uint("expert_count")),
            numExpertsUsed:        int(c.Uint("expert_used_count")),
            eps:                   c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:              c.Float("rope.freq_base"),
            ropeScale:             c.Float("rope.scaling.factor", 1.),
            originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
        },
    }

    m.Cache = kvcache.NewWrapperCache(
        kvcache.NewSWAMemCache(int32(c.Uint("attention.sliding_window")), 4096, m.Shift),
        kvcache.NewCausalCache(m.Shift),
    )
    m.Cache.SetConfig(ml.CacheConfig{CachePadding: 32, PermutedV: true})
    return &m, nil
}

func init() {
    model.Register("gptoss", New)
}
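For intuition about the routing in MLPBlock.Forward above: the router scores every expert per token, TopK keeps numExpertsUsed of them, and their outputs are blended with softmax-normalized weights. A rough standalone sketch of that selection step in plain Go (the values are illustrative, not the ml tensor API):

package main

import (
    "fmt"
    "math"
    "sort"
)

// routeTopK mirrors the TopK + Rows + Softmax sequence for a single token:
// pick the k highest-scoring experts, then softmax over just their logits.
func routeTopK(logits []float32, k int) (idx []int, weights []float32) {
    idx = make([]int, len(logits))
    for i := range idx {
        idx[i] = i
    }
    sort.Slice(idx, func(a, b int) bool { return logits[idx[a]] > logits[idx[b]] })
    idx = idx[:k]

    var sum float64
    weights = make([]float32, k)
    for i, e := range idx {
        weights[i] = float32(math.Exp(float64(logits[e])))
        sum += float64(weights[i])
    }
    for i := range weights {
        weights[i] /= float32(sum)
    }
    return idx, weights
}

func main() {
    experts, w := routeTopK([]float32{0.1, 2.0, -1.0, 1.5}, 2)
    fmt.Println(experts, w) // [1 3] with weights summing to 1
}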
@@ -4,6 +4,7 @@ import (
    _ "github.com/ollama/ollama/model/models/gemma2"
    _ "github.com/ollama/ollama/model/models/gemma3"
    _ "github.com/ollama/ollama/model/models/gemma3n"
+   _ "github.com/ollama/ollama/model/models/gptoss"
    _ "github.com/ollama/ollama/model/models/llama"
    _ "github.com/ollama/ollama/model/models/llama4"
    _ "github.com/ollama/ollama/model/models/mistral3"

@@ -36,6 +36,7 @@ type ErrorResponse struct {
type Message struct {
    Role      string     `json:"role"`
    Content   any        `json:"content"`
+   Reasoning string     `json:"reasoning,omitempty"`
    ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

@@ -81,6 +82,10 @@ type StreamOptions struct {
    IncludeUsage bool `json:"include_usage"`
}

+type Reasoning struct {
+   Effort *string `json:"effort,omitempty"`
+}

type ChatCompletionRequest struct {
    Model    string    `json:"model"`
    Messages []Message `json:"messages"`
@@ -95,6 +100,7 @@ type ChatCompletionRequest struct {
    TopP           *float64        `json:"top_p"`
    ResponseFormat *ResponseFormat `json:"response_format"`
    Tools          []api.Tool      `json:"tools"`
+   Reasoning      *Reasoning      `json:"reasoning,omitempty"`
}

type ChatCompletion struct {
@@ -253,7 +259,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
    SystemFingerprint: "fp_ollama",
    Choices: []Choice{{
        Index:   0,
-       Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
+       Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
        FinishReason: func(reason string) *string {
            if len(toolCalls) > 0 {
                reason = "tool_calls"
@@ -278,10 +284,10 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
    SystemFingerprint: "fp_ollama",
    Choices: []ChunkChoice{{
        Index: 0,
-       Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
+       Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
        FinishReason: func(reason string) *string {
            if len(reason) > 0 {
-               if toolCallSent {
+               if toolCallSent || len(toolCalls) > 0 {
                    return &finishReasonToolCalls
                }
                return &reason
@@ -397,7 +403,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
    for _, msg := range r.Messages {
        switch content := msg.Content.(type) {
        case string:
-           messages = append(messages, api.Message{Role: msg.Role, Content: content})
+           messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning})
        case []any:
            for _, c := range content {
                data, ok := c.(map[string]any)
@@ -508,6 +514,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
        options["top_p"] = 1.0
    }

+   if r.Reasoning != nil {
+       options["reasoning"] = *r.Reasoning.Effort
+   }

    var format json.RawMessage
    if r.ResponseFormat != nil {
        switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
@@ -521,6 +531,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
        }
    }

+   var think *api.ThinkValue
+   if r.Reasoning != nil {
+       think = &api.ThinkValue{
+           Value: *r.Reasoning.Effort,
+       }
+   }

    return &api.ChatRequest{
        Model:    r.Model,
        Messages: messages,
@@ -528,6 +545,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
        Options:  options,
        Stream:   &r.Stream,
        Tools:    r.Tools,
+       Think:    think,
    }, nil
}

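Tracing the new plumbing end to end: an OpenAI-style request whose reasoning.effort is set arrives in fromChatRequest and leaves as an api.ThinkValue carrying the effort string. A hedged sketch using the types from this diff (the model name is made up; error handling elided):

effort := "high"
req := ChatCompletionRequest{
    Model:     "gpt-oss:20b", // illustrative model name
    Messages:  []Message{{Role: "user", Content: "Why is the sky blue?"}},
    Reasoning: &Reasoning{Effort: &effort},
}
chatReq, _ := fromChatRequest(req)
fmt.Println(chatReq.Think.Value) // "high"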
379
server/harmonyparser.go
Normal file
@@ -0,0 +1,379 @@
package server

import (
    "context"
    "log/slog"
    "strings"
    "unicode"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/logutil"
)

type harmonyParserState int

const (
    harmonyParserState_LookingForMessageStart harmonyParserState = iota
    harmonyParserState_ParsingHeader
    harmonyParserState_ParsingContent
)

func shouldUseHarmony(model Model) bool {
    if model.Config.ModelFamily == "gptoss" {
        // heuristic to check whether the template expects to be parsed via harmony:
        // search for harmony tags that are nearly always used
        if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
            return true
        }
    }

    return false
}

func (s harmonyParserState) String() string {
    switch s {
    // we're looking for the message start tag
    case harmonyParserState_LookingForMessageStart:
        return "LookingForMessageStart"
    case harmonyParserState_ParsingHeader:
        return "ParsingHeader"
    case harmonyParserState_ParsingContent:
        return "ParsingContent"
    default:
        return "Unknown"
    }
}

type HarmonyParser struct {
    state           harmonyParserState
    MessageStartTag string
    MessageEndTag   string
    HeaderEndTag    string
    acc             strings.Builder
    lifetimeAcc     strings.Builder
}

type HarmonyEvent interface {
    isHarmonyEvent()
}

type HarmonyEventMessageStart struct{}

func (HarmonyEventMessageStart) isHarmonyEvent() {}

type HarmonyEventHeaderComplete struct {
    Header HarmonyHeader
}

func (HarmonyEventHeaderComplete) isHarmonyEvent() {}

type HarmonyEventContentEmitted struct {
    Content string
}

func (HarmonyEventContentEmitted) isHarmonyEvent() {}

type HarmonyEventMessageEnd struct{}

func (HarmonyEventMessageEnd) isHarmonyEvent() {}

type HarmonyHeader struct {
    Role      string
    Channel   string
    Recipient string
}

func (s *HarmonyParser) AddImplicitStart() {
    s.acc.WriteString("<|start|>assistant")
}

func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
    if lastMessage != nil && lastMessage.Role == "assistant" {
        // handle prefilling conditions
        if lastMessage.Content != "" {
            s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
            return
        } else if lastMessage.Thinking != "" {
            s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
            return
        }
    }
    s.AddImplicitStart()
}

func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
    s.lifetimeAcc.WriteString(content)
    s.acc.WriteString(content)

    var events []HarmonyEvent

    keepLooping := true
    // we loop because we might pass through multiple parsing states in a single
    // call to addContent, and we want to make sure callers don't have to wait for
    // data that's already unambiguous
    for keepLooping {
        var newEvents []HarmonyEvent
        newEvents, keepLooping = eat(s)
        events = append(events, newEvents...)
    }

    return events
}

// the additional bool return is true iff we should continue eating
func eat(s *HarmonyParser) ([]HarmonyEvent, bool) {
    switch s.state {
    case harmonyParserState_LookingForMessageStart:
        // does the acc contain the message start tag?
        if strings.Contains(s.acc.String(), s.MessageStartTag) {
            // split the acc into the message start tag and the rest
            split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2)
            before := split[0]
            if before != "" {
                slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String())
            }
            after := split[1]
            s.acc.Reset()
            s.acc.WriteString(after)
            s.state = harmonyParserState_ParsingHeader
            return []HarmonyEvent{HarmonyEventMessageStart{}}, true
        }

        // no match, so we keep accumulating
        return nil, false
    case harmonyParserState_ParsingHeader:
        if strings.Contains(s.acc.String(), s.HeaderEndTag) {
            split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2)
            header := split[0]
            after := split[1]
            s.acc.Reset()
            s.acc.WriteString(after)
            s.state = harmonyParserState_ParsingContent
            return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true
        }
        return nil, false
    case harmonyParserState_ParsingContent:
        if strings.Contains(s.acc.String(), s.MessageEndTag) {
            // if we already have the message end tag, we can emit the content up to it
            split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2)
            content := split[0]
            after := split[1]
            s.acc.Reset()
            s.acc.WriteString(after)
            s.state = harmonyParserState_LookingForMessageStart
            events := []HarmonyEvent{}
            if content != "" {
                events = append(events, HarmonyEventContentEmitted{Content: content})
            }
            events = append(events, HarmonyEventMessageEnd{})
            return events, true
        } else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 {
            // if our suffix contains the start of the message end tag, we can emit
            // the content up to the start of the message end tag
            content := s.acc.String()[:len(s.acc.String())-overlapLen]
            remaining := s.acc.String()[len(s.acc.String())-overlapLen:]
            s.acc.Reset()
            s.acc.WriteString(remaining)
            // emit the content we know isn't part of the message end tag, and keep
            // accumulating to disambiguate the rest
            if content == "" {
                return nil, false
            }
            return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
        } else {
            // no end tag, so it's still normal content that we can immediately emit
            content := s.acc.String()
            if content == "" {
                return nil, false
            }
            s.acc.Reset()
            return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
        }
    }

    return nil, false
}

func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
    harmonyHeader := HarmonyHeader{}

    // if `<|constrain|>` is present, ensure it has a space before it so it gets
    // parsed as a separate token, even if the model didn't include the space
    if strings.Contains(raw, "<|constrain|>") {
        raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1)
        raw = strings.TrimSpace(raw)
    }

    // look for the optional channel tag, which is `<|channel|>` followed by the
    // channel name, all without any whitespace
    channelIndex := strings.Index(raw, "<|channel|>")
    if channelIndex != -1 {
        before := raw[:channelIndex]
        after := raw[channelIndex+len("<|channel|>"):]
        // the channel name is `after` all the way up to the first (if any) whitespace character
        idx := strings.IndexFunc(after, func(r rune) bool {
            return unicode.IsSpace(r)
        })
        if idx == -1 {
            idx = len(after)
        }
        harmonyHeader.Channel = after[:idx]
        after = after[idx:]
        // now we remove the channel tag from the raw string to further process
        raw = before + after
        raw = strings.TrimSpace(raw)
    }

    // split the header into whitespace-separated tokens
    tokens := strings.Fields(raw)

    // the first token is treated as the role
    if len(tokens) == 0 {
        slog.Error("harmony parser: missing role in header", "header", raw)
        return harmonyHeader
    }
    role := tokens[0]
    tokens = tokens[1:]
    // special case: if role starts with to= then it's a tool call
    if strings.HasPrefix(role, "to=") {
        harmonyHeader.Recipient = role[3:]
        harmonyHeader.Role = "tool"
    } else {
        harmonyHeader.Role = role
    }

    // the recipient (if any) can be specified before or after the channel tag, so
    // we check it at the end once we've already parsed the channel and role
    if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") {
        harmonyHeader.Recipient = tokens[0][3:]
    }

    return harmonyHeader
}

// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
    max := min(len(delim), len(s))
    for i := max; i > 0; i-- {
        if strings.HasSuffix(s, delim[:i]) {
            return i
        }
    }
    return 0
}
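The overlap helper is what makes streaming safe: anything that cannot be a prefix of the end tag is emitted immediately, and only the ambiguous suffix is held back until more bytes arrive. A self-contained sketch of that gating (needs Go 1.21+ for the min builtin):

package main

import (
    "fmt"
    "strings"
)

// overlap reports the longest suffix of s that is also a prefix of delim,
// matching the helper above.
func overlap(s, delim string) int {
    max := min(len(delim), len(s))
    for i := max; i > 0; i-- {
        if strings.HasSuffix(s, delim[:i]) {
            return i
        }
    }
    return 0
}

func main() {
    acc := "The answer is 4<|e"
    if n := overlap(acc, "<|end|>"); n > 0 {
        // emit "The answer is 4", hold back "<|e" until it's disambiguated
        fmt.Printf("emit %q, hold back %q\n", acc[:len(acc)-n], acc[len(acc)-n:])
    }
}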

// harmonyMessageState represents the current state of message processing
type harmonyMessageState int

const (
    harmonyMessageState_Normal harmonyMessageState = iota
    harmonyMessageState_Thinking
    harmonyMessageState_ToolCalling
)

// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
    state         harmonyMessageState
    harmonyParser *HarmonyParser
}

// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
    return &HarmonyMessageHandler{
        state: harmonyMessageState_Normal,
        harmonyParser: &HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag:   "<|end|>",
            HeaderEndTag:    "<|message|>",
        },
    }
}

// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
    contentSb := strings.Builder{}
    thinkingSb := strings.Builder{}
    toolContentSb := strings.Builder{}

    events := h.harmonyParser.AddContent(content)
    for _, event := range events {
        switch event := event.(type) {
        case HarmonyEventHeaderComplete:
            slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
            switch event.Header.Channel {
            case "analysis":
                if event.Header.Recipient != "" {
                    h.state = harmonyMessageState_ToolCalling
                    // event.Header.Recipient is the tool name, something like
                    // "browser.search" for a built-in, or "functions.calc" for a
                    // custom one
                    toolParser.SetToolName(event.Header.Recipient)
                } else {
                    h.state = harmonyMessageState_Thinking
                }
            case "commentary":
                if event.Header.Recipient != "" {
                    h.state = harmonyMessageState_ToolCalling
                    toolParser.SetToolName(event.Header.Recipient)
                } else {
                    h.state = harmonyMessageState_Normal
                }
            case "final":
                h.state = harmonyMessageState_Normal
            }
        case HarmonyEventContentEmitted:
            slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
            if h.state == harmonyMessageState_Normal {
                contentSb.WriteString(event.Content)
            } else if h.state == harmonyMessageState_Thinking {
                thinkingSb.WriteString(event.Content)
            } else if h.state == harmonyMessageState_ToolCalling {
                toolContentSb.WriteString(event.Content)
            }
        case HarmonyEventMessageEnd:
            h.state = harmonyMessageState_Normal
        }
    }
    return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}

func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
    return &HarmonyToolCallAccumulator{
        state:           harmonyToolCallState_Normal,
        currentToolName: nil,
    }
}

type harmonyToolCallState int

const (
    harmonyToolCallState_Normal harmonyToolCallState = iota
    harmonyToolCallState_ToolCalling
)

type HarmonyToolCallAccumulator struct {
    state           harmonyToolCallState
    acc             strings.Builder
    currentToolName *string
}

func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) {
    a.currentToolName = &toolName
}

func (a *HarmonyToolCallAccumulator) Add(content string) {
    a.acc.WriteString(content)
}

func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
    str := a.acc.String()
    a.state = harmonyToolCallState_Normal
    a.acc.Reset()
    return a.currentToolName, str
}

func (a *HarmonyToolCallAccumulator) Content() string {
    return a.acc.String()
}
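Putting the pieces together, a hedged sketch of how a caller inside this package might drive the handler with streamed chunks (the chunks and tool name are invented; fmt is assumed imported):

handler := NewHarmonyMessageHandler()
handler.harmonyParser.AddImplicitStart()
toolParser := handler.CreateToolParser()

chunks := []string{
    "<|channel|>analysis<|message|>Need the weather.<|end|>",
    "<|start|>assistant<|channel|>commentary to=functions.get_weather<|message|>",
    `{"location":"Paris"}`,
    "<|end|>",
}
for _, c := range chunks {
    content, thinking, toolContent := handler.AddContent(c, toolParser)
    toolParser.Add(toolContent)
    fmt.Println(content, thinking) // thinking carries "Need the weather."
}
name, args := toolParser.Drain()
fmt.Println(*name, args) // functions.get_weather {"location":"Paris"}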
469
server/harmonyparser_test.go
Normal file
@@ -0,0 +1,469 @@
package server

import (
    "fmt"
    "reflect"
    "testing"
)

func TestHeaderParsing(t *testing.T) {
    tests := []struct {
        in, wantRole, wantChannel, wantRecipient string
    }{
        {
            in:            "assistant<|channel|>analysis",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "",
        },
        {
            in:            "assistant<|channel|>analysis to=functions.get_weather",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "functions.get_weather",
        },
        {
            in:            "assistant to=functions.get_weather<|channel|>analysis",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "functions.get_weather",
        },
        // special case where the role is replaced by the recipient (matches reference code)
        {
            in:            "to=functions.get_weather<|channel|>analysis",
            wantRole:      "tool",
            wantChannel:   "analysis",
            wantRecipient: "functions.get_weather",
        },
        // extra token after the recipient is ignored
        {
            in:            "assistant to=functions.get_weather abc<|channel|>analysis",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "functions.get_weather",
        },
        // with constrain tag, recipient after channel tag
        {
            in:            "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
            wantRole:      "assistant",
            wantChannel:   "commentary",
            wantRecipient: "functions.get_weather",
        },
        // with constrain tag, recipient before channel tag
        {
            in:            "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json",
            wantRole:      "assistant",
            wantChannel:   "commentary",
            wantRecipient: "functions.get_weather",
        },
        // constrain tag without space
        {
            in:            "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json",
            wantRole:      "assistant",
            wantChannel:   "commentary",
            wantRecipient: "functions.get_weather",
        },
        // constrain tag without space, different order
        {
            in:            "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json",
            wantRole:      "assistant",
            wantChannel:   "commentary",
            wantRecipient: "functions.get_weather",
        },
    }
    for i, tt := range tests {
        parser := HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag:   "<|end|>",
            HeaderEndTag:    "<|message|>",
        }
        header := parser.parseHeader(tt.in)

        if header.Role != tt.wantRole {
            t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole)
        }
        if header.Channel != tt.wantChannel {
            t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel)
        }
        if header.Recipient != tt.wantRecipient {
            t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient)
        }
    }
}

func TestHarmonyParserHeaderEvent(t *testing.T) {
    tests := []struct {
        in, wantRole, wantChannel, wantRecipient string
        implicitStart                            bool
    }{
        {
            in:            "<|start|>user<|message|>What is 2 + 2?<|end|>",
            wantRole:      "user",
            wantChannel:   "",
            wantRecipient: "",
        },
        {
            in:            "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "",
        },
        {
            in:            "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>",
            wantRole:      "assistant",
            wantChannel:   "commentary",
            wantRecipient: "functions.get_weather",
        },
        {
            in:            "<|channel|>analysis<|message|>User asks weather in SF. We need location. Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>",
            wantRole:      "assistant",
            wantChannel:   "analysis",
            wantRecipient: "",
            implicitStart: true,
        },
    }
    for i, tt := range tests {
        parser := HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag:   "<|end|>",
            HeaderEndTag:    "<|message|>",
        }
        if tt.implicitStart {
            parser.AddImplicitStart()
        }
        gotEvents := parser.AddContent(tt.in)
        if len(gotEvents) == 0 {
            t.Errorf("case %d: got no events, want at least one", i)
        }

        var firstHeaderEvent *HarmonyEventHeaderComplete
        // print events
        for _, event := range gotEvents {
            fmt.Printf("event: %+v\n", event)
        }
        for _, event := range gotEvents {
            if event, ok := event.(HarmonyEventHeaderComplete); ok {
                firstHeaderEvent = &event
                break
            }
        }

        if firstHeaderEvent == nil {
            t.Errorf("case %d: got no header complete event, want one", i)
            continue
        }
        gotHeader := firstHeaderEvent.Header
        if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient {
            t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient)
        }
    }
}

func TestHarmonyParserNonStreaming(t *testing.T) {
    tests := []struct {
        in            string
        implicitStart bool
        wantEvents    []HarmonyEvent
    }{
        {
            in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
            wantEvents: []HarmonyEvent{
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                HarmonyEventContentEmitted{Content: "What is 2 + 2?"},
                HarmonyEventMessageEnd{},
            },
        },
        {
            in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>",
            wantEvents: []HarmonyEvent{
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}},
                HarmonyEventContentEmitted{Content: "The answer is 4"},
                HarmonyEventMessageEnd{},
            },
        },
        {
            in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>",
            wantEvents: []HarmonyEvent{
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
                HarmonyEventContentEmitted{Content: "Computing..."},
                HarmonyEventMessageEnd{},
            },
        },
        {
            in: "<|start|>user<|message|><|end|>",
            wantEvents: []HarmonyEvent{
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                HarmonyEventMessageEnd{},
            },
        },
        {
            in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>",
            wantEvents: []HarmonyEvent{
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                HarmonyEventContentEmitted{Content: "Hello"},
                HarmonyEventMessageEnd{},
                HarmonyEventMessageStart{},
                HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}},
                HarmonyEventContentEmitted{Content: "Hi!"},
                HarmonyEventMessageEnd{},
            },
        },
        {
            in:            "<|channel|>analysis<|message|>Thinking about the request<|end|>",
            implicitStart: true,
            wantEvents:    []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}},
        },
    }
    for i, tt := range tests {
        parser := HarmonyParser{
            MessageStartTag: "<|start|>",
            MessageEndTag:   "<|end|>",
            HeaderEndTag:    "<|message|>",
        }
        if tt.implicitStart {
            parser.AddImplicitStart()
        }
        gotEvents := parser.AddContent(tt.in)
        if !reflect.DeepEqual(gotEvents, tt.wantEvents) {
            t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents)
        }
    }
}

func TestHarmonyParserStreaming(t *testing.T) {
    type step struct {
        input      string
        wantEvents []HarmonyEvent
    }

    cases := []struct {
        desc          string
        implicitStart bool
        steps         []step
    }{
        {
            desc: "simple message streamed character by character",
            steps: []step{
                {
                    input:      "<",
                    wantEvents: nil,
                },
                {
                    input:      "|",
                    wantEvents: nil,
                },
                {
                    input:      "start|>u",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
                },
                {
                    input:      "ser<|mess",
                    wantEvents: nil,
                },
                {
                    input: "age|>Hi",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                        HarmonyEventContentEmitted{Content: "Hi"},
                    },
                },
                {
                    input:      " there",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}},
                },
                {
                    input:      "<|e",
                    wantEvents: nil,
                },
                {
                    input:      "nd|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
            desc: "message with channel streamed",
            steps: []step{
                {
                    input:      "<|start|>assistant",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
                },
                {
                    input:      "<|chan",
                    wantEvents: nil,
                },
                {
                    input:      "nel|>analysis",
                    wantEvents: nil,
                },
                {
                    input:      "<|message|>",
                    wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}},
                },
                {
                    input:      "Thinking",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}},
                },
                {
                    input:      "...",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
            desc: "message with channel and recipient",
            steps: []step{
                {
                    input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventMessageStart{},
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
                    },
                },
                {
                    input:      "{\"x\": 5}",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
desc: "message with channel and recipient (receipient before channel)",
|
||||
            steps: []step{
                {
                    input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventMessageStart{},
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
                    },
                },
                {
                    input:      "{\"x\": 5}",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
            desc:          "implicit start with channel",
            implicitStart: true,
            steps: []step{
                {
                    input:      "<|channel|>thinking",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
                },
                {
                    input:      "<|message|>",
                    wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}},
                },
                {
                    input:      "Processing request",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
            desc: "multiple messages streamed",
            steps: []step{
                {
                    input: "<|start|>user<|message|>Hello<|end|>",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventMessageStart{},
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                        HarmonyEventContentEmitted{Content: "Hello"},
                        HarmonyEventMessageEnd{},
                    },
                },
                {
                    input:      "<|start|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
                },
                {
                    input:      "assistant<|message|>",
                    wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}},
                },
                {
                    input:      "Hi!",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
        {
            desc: "empty message",
            steps: []step{
                {
                    input: "<|start|>system<|message|><|end|>",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventMessageStart{},
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}},
                        HarmonyEventMessageEnd{},
                    },
                },
            },
        },
        {
            desc: "partial tag that looks like end but isn't",
            steps: []step{
                {
                    input: "<|start|>user<|message|>test<|e",
                    wantEvents: []HarmonyEvent{
                        HarmonyEventMessageStart{},
                        HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
                        HarmonyEventContentEmitted{Content: "test"},
                    },
                },
                {
                    input:      "xample|>more",
                    wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}},
                },
                {
                    input:      "<|end|>",
                    wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
                },
            },
        },
    }

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            parser := HarmonyParser{
                MessageStartTag: "<|start|>",
                MessageEndTag:   "<|end|>",
                HeaderEndTag:    "<|message|>",
            }
            if tc.implicitStart {
                parser.AddImplicitStart()
            }

            for i, step := range tc.steps {
                gotEvents := parser.AddContent(step.input)
                if !reflect.DeepEqual(gotEvents, step.wantEvents) {
                    t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
                }
            }
        })
    }
}
@@ -111,7 +111,8 @@ func (m *Model) Capabilities() []model.Capability {

    // Check for thinking capability
    openingTag, closingTag := thinking.InferTags(m.Template.Template)
-   if openingTag != "" && closingTag != "" {
+   hasTags := openingTag != "" && closingTag != ""
+   if hasTags || m.Config.ModelFamily == "gptoss" {
        capabilities = append(capabilities, model.CapabilityThinking)
    }


@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
    var system []api.Message

    // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -42,11 +42,13 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
    }

    thinkVal := false
+   thinkLevel := ""
    if think != nil {
-       thinkVal = *think
+       thinkVal = think.AsBool()
+       thinkLevel = think.AsString()
    }
    var b bytes.Buffer
-   if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
+   if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
        return "", nil, err
    }

@@ -101,10 +103,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
    // truncate any messages that do not fit into the context window
    var b bytes.Buffer
    thinkVal := false
+   thinkLevel := ""
    if think != nil {
-       thinkVal = *think
+       thinkVal = think.AsBool()
+       thinkLevel = think.AsString()
    }
-   if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
+   if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
        return "", nil, err
    }


@@ -209,7 +209,7 @@ func TestChatPrompt(t *testing.T) {
    model := tt.model
    opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
    think := false
-   prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
+   prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
    if tt.error == nil && err != nil {
        t.Fatal(err)
    } else if tt.error != nil && err != tt.error {

140
server/routes.go
@@ -112,6 +112,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
        return nil, nil, nil, err
    }

+   // This model requires a minimum context to function effectively
+   if slices.Contains(model.Config.ModelFamilies, "gptoss") {
+       opts.NumCtx = max(opts.NumCtx, 8192)
+   }

    runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
    var runner *runnerRef
    select {
@@ -182,11 +187,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

+   useHarmony := shouldUseHarmony(*m) && !req.Raw
+   var harmonyMessageHandler *HarmonyMessageHandler
+   var harmonyToolParser *HarmonyToolCallAccumulator
+   if useHarmony {
+       harmonyMessageHandler = NewHarmonyMessageHandler()
+       harmonyMessageHandler.harmonyParser.AddImplicitStart()
+       harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+   }

+   // Validate Think value: string values currently only allowed for gptoss models
+   if req.Think != nil && req.Think.IsString() && !useHarmony {
+       c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.AsString())})
+       return
+   }

    caps := []model.Capability{model.CapabilityCompletion}
    if req.Suffix != "" {
        caps = append(caps, model.CapabilityInsert)
    }
-   if req.Think != nil && *req.Think {
+   if req.Think != nil && req.Think.AsBool() {
        caps = append(caps, model.CapabilityThinking)
        // TODO(drifkin): consider adding a warning if it's false and the model
        // doesn't support thinking. It's not strictly required, but it can be a
@@ -261,7 +281,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
    }

-   values.Think = req.Think != nil && *req.Think
+   values.Think = req.Think != nil && req.Think.AsBool()
+   values.ThinkLevel = ""
+   if req.Think != nil {
+       values.ThinkLevel = req.Think.AsString()
+   }
    values.IsThinkSet = req.Think != nil

    var b bytes.Buffer
@@ -284,11 +308,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
    }

    var thinkingState *thinking.Parser
-   openingTag, closingTag := thinking.InferTags(m.Template.Template)
-   if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
-       thinkingState = &thinking.Parser{
-           OpeningTag: openingTag,
-           ClosingTag: closingTag,
+   if !useHarmony {
+       openingTag, closingTag := thinking.InferTags(m.Template.Template)
+       if req.Think != nil && req.Think.AsBool() && openingTag != "" && closingTag != "" {
+           thinkingState = &thinking.Parser{
+               OpeningTag: openingTag,
+               ClosingTag: closingTag,
+           }
        }
    }

@@ -316,7 +342,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        },
    }

-   if thinkingState != nil {
+   if useHarmony {
+       content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
+       res.Response = content
+       res.Thinking = thinking
+       harmonyToolParser.Add(toolContent)
+   } else if thinkingState != nil {
        thinking, content := thinkingState.AddContent(cr.Content)
        res.Thinking = thinking
        res.Response = content
@@ -327,6 +358,25 @@ func (s *Server) GenerateHandler(c *gin.Context) {
    }

    if cr.Done {
+       if useHarmony {
+           toolName, toolContent := harmonyToolParser.Drain()
+           if toolName != nil {
+               *toolName = strings.TrimPrefix(*toolName, "functions.")
+               var args api.ToolCallFunctionArguments
+               if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+                   ch <- gin.H{"error parsing tool call": err.Error()}
+                   return
+               }

+               res.ToolCalls = append(res.ToolCalls, api.ToolCall{
+                   Function: api.ToolCallFunction{
+                       Name:      *toolName,
+                       Arguments: args,
+                   },
+               })
+           }
+       }

        res.DoneReason = cr.DoneReason.String()
        res.TotalDuration = time.Since(checkpointStart)
        res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -341,6 +391,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        }
    }

+   if useHarmony {
+       // only send messages with meaningful content (empty messages confuse clients)
+       if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
+           ch <- res
+       }

+       return
+   }

    ch <- res
}); err != nil {
    ch <- gin.H{"error": err.Error()}
@@ -1471,7 +1530,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
    if len(req.Tools) > 0 {
        caps = append(caps, model.CapabilityTools)
    }
-   if req.Think != nil && *req.Think {
+   if req.Think != nil && req.Think.AsBool() {
        caps = append(caps, model.CapabilityThinking)
    }

@@ -1521,9 +1580,30 @@ func (s *Server) ChatHandler(c *gin.Context) {
        return
    }

+   useHarmony := shouldUseHarmony(*m)

+   // Validate Think value: string values currently only allowed for gptoss models
+   if req.Think != nil && req.Think.IsString() && !useHarmony {
+       c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.AsString())})
+       return
+   }

+   var harmonyMessageHandler *HarmonyMessageHandler
+   var harmonyToolParser *HarmonyToolCallAccumulator

+   if useHarmony {
+       harmonyMessageHandler = NewHarmonyMessageHandler()
+       var lastMessage *api.Message
+       if len(msgs) > 0 {
+           lastMessage = &msgs[len(msgs)-1]
+       }
+       harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
+       harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+   }

    var thinkingState *thinking.Parser
    openingTag, closingTag := thinking.InferTags(m.Template.Template)
-   if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
+   if req.Think != nil && req.Think.AsBool() && openingTag != "" && closingTag != "" {
        thinkingState = &thinking.Parser{
            OpeningTag: openingTag,
            ClosingTag: closingTag,
@@ -1531,7 +1611,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
    }

    var toolParser *tools.Parser
-   if len(req.Tools) > 0 {
+   if len(req.Tools) > 0 && !useHarmony {
        toolParser = tools.NewParser(m.Template.Template, req.Tools)
    }

@@ -1557,6 +1637,38 @@ func (s *Server) ChatHandler(c *gin.Context) {
            EvalDuration: r.EvalDuration,
        },
    }
+   if r.Done {
+       res.DoneReason = r.DoneReason.String()
+       res.TotalDuration = time.Since(checkpointStart)
+       res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+   }

+   if useHarmony {
+       content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
+       res.Message.Content = content
+       res.Message.Thinking = thinking
+       harmonyToolParser.Add(toolContent)

+       if r.Done {
+           toolName, toolContent := harmonyToolParser.Drain()
+           if toolName != nil {
+               *toolName = strings.TrimPrefix(*toolName, "functions.")
+               var args api.ToolCallFunctionArguments
+               if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+                   ch <- gin.H{"error parsing tool call": err.Error()}
+                   return
+               }
+               res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
+           }
+       }

+       // only send messages with meaningful content (empty messages confuse clients)
+       if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
+           ch <- res
+       }

+       return
+   }

    if thinkingState != nil {
        thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
@@ -1568,12 +1680,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
        res.Message.Thinking = thinkingContent
    }

-   if r.Done {
-       res.DoneReason = r.DoneReason.String()
-       res.TotalDuration = time.Since(checkpointStart)
-       res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-   }

    if len(req.Tools) > 0 {
        toolCalls, content := toolParser.Add(res.Message.Content)
        if len(content) > 0 {

@@ -150,7 +150,7 @@ func TestGenerateChat(t *testing.T) {
    Messages: []api.Message{
        {Role: "user", Content: "Hello!"},
    },
-   Think: &think,
+   Think: &api.ThinkValue{Value: think},
})

if w.Code != http.StatusBadRequest {

712
server/routes_harmony_streaming_test.go
Normal file
712
server/routes_harmony_streaming_test.go
Normal file
@@ -0,0 +1,712 @@
package server

// this test file is to test integration of harmony parser into routes.go (as
// opposed to harmonyparser_test.go, which tests the parser in isolation)

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
)

func getTestTools() []api.Tool {
	return []api.Tool{
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get the current weather in a given location",
				Parameters: struct {
					Type       string   `json:"type"`
					Defs       any      `json:"$defs,omitempty"`
					Items      any      `json:"items,omitempty"`
					Required   []string `json:"required"`
					Properties map[string]struct {
						Type        api.PropertyType `json:"type"`
						Items       any              `json:"items,omitempty"`
						Description string           `json:"description"`
						Enum        []any            `json:"enum,omitempty"`
					} `json:"properties"`
				}{
					Type:     "object",
					Required: []string{"location"},
					Properties: map[string]struct {
						Type        api.PropertyType `json:"type"`
						Items       any              `json:"items,omitempty"`
						Description string           `json:"description"`
						Enum        []any            `json:"enum,omitempty"`
					}{
						"location": {
							Type:        api.PropertyType{"string"},
							Description: "The city and state, e.g. San Francisco, CA",
						},
					},
				},
			},
		},
		{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "calculate",
				Description: "Calculate a mathematical expression",
				Parameters: struct {
					Type       string   `json:"type"`
					Defs       any      `json:"$defs,omitempty"`
					Items      any      `json:"items,omitempty"`
					Required   []string `json:"required"`
					Properties map[string]struct {
						Type        api.PropertyType `json:"type"`
						Items       any              `json:"items,omitempty"`
						Description string           `json:"description"`
						Enum        []any            `json:"enum,omitempty"`
					} `json:"properties"`
				}{
					Type:     "object",
					Required: []string{"expression"},
					Properties: map[string]struct {
						Type        api.PropertyType `json:"type"`
						Items       any              `json:"items,omitempty"`
						Description string           `json:"description"`
						Enum        []any            `json:"enum,omitempty"`
					}{
						"expression": {
							Type:        api.PropertyType{"string"},
							Description: "The mathematical expression to calculate",
						},
					},
				},
			},
		},
	}
}

func createHarmonyTestModel(t *testing.T) (string, string) {
	t.Helper()

	return createBinFile(t, ggml.KV{
		"general.architecture":          "gptoss",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})
}

// TestChatHarmonyParserStreamingRealtime verifies that chunks are emitted as soon as they're available
func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type step struct {
		input         llm.CompletionResponse
		wantContent   string
		wantThinking  string
		wantToolCalls []api.ToolCall
	}

	testCases := []struct {
		name  string
		steps []step
		only  bool
	}{
		{
			name: "content streams as it arrives",
			steps: []step{
				{
					input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
					wantContent: "Hello",
				},
				{
					input:       llm.CompletionResponse{Content: ", world", Done: false},
					wantContent: ", world",
				},
				{
					input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "!",
				},
			},
		},
		{
			name: "thinking streams separately from content",
			steps: []step{
				{
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
					wantThinking: "Thinking...",
				},
				{
					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
					// No output expected - just closes the analysis message and resets state to normal
				},
				{
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
					wantContent: "Answer", // After message end, state is reset to normal
				},
				{
					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					// No output expected - just closes the assistant message
				},
			},
		},
		{
			name: "partial tags buffer until complete",
			steps: []step{
				{
					input: llm.CompletionResponse{Content: "<|chan", Done: false},
					// No output - partial tag
				},
				{
					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
					// No output - still building tags
				},
				{
					input:        llm.CompletionResponse{Content: "age|>Deep ", Done: false},
					wantThinking: "Deep ",
				},
				{
					input:        llm.CompletionResponse{Content: "thought<|end|>", Done: false},
					wantThinking: "thought",
				},
				{
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "Done", // After message end, state is reset to normal
				},
			},
		},
		{
			name: "simple assistant after analysis",
			steps: []step{
				{
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent:  "Answer",
					wantThinking: "Think",
				},
			},
		},
		{
			name: "tool call parsed and returned correctly",
			steps: []step{
				{
					input:       llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "The weather is sunny",
					wantToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "San Francisco",
								},
							},
						},
					},
				},
			},
		},
		{
			name: "tool call with streaming JSON across chunks",
			steps: []step{
				{
					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
					// No output yet - incomplete JSON
				},
				{
					input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
					// Still no output - incomplete JSON
				},
				{
					input: llm.CompletionResponse{Content: "2\"}", Done: true},
					wantToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "calculate",
								Arguments: api.ToolCallFunctionArguments{
									"expression": "2+2",
								},
							},
						},
					},
				},
			},
		},
	}

	anyOnlies := false
	for _, tc := range testCases {
		if tc.only {
			anyOnlies = true
		}
	}

	for _, tc := range testCases {
		if anyOnlies && !tc.only {
			continue
		}

		t.Run(tc.name, func(t *testing.T) {
			var chunks []api.ChatResponse
			chunkIdx := 0

			mockResponses := make([]llm.CompletionResponse, len(tc.steps))
			for i, step := range tc.steps {
				mockResponses[i] = step.input
			}

			mock := mockRunner{
				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
					for _, resp := range mockResponses {
						fn(resp)
						// Give the handler time to process each response
						time.Sleep(30 * time.Millisecond)
					}
					return nil
				},
			}

			s := Server{
				sched: &Scheduler{
					pendingReqCh:  make(chan *LlmRequest, 1),
					finishedReqCh: make(chan *LlmRequest, 1),
					expiredCh:     make(chan *runnerRef, 1),
					unloadedCh:    make(chan any, 1),
					loaded:        make(map[string]*runnerRef),
					newServerFn:   newMockServer(&mock),
					getGpuFn:      discover.GetGPUInfo,
					getCpuFn:      discover.GetCPUInfo,
					reschedDelay:  100 * time.Millisecond,
					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
						req.successCh <- &runnerRef{
							llama: &mock,
						}
					},
				},
			}

			go s.sched.Run(t.Context())

			// Create a simple test model
			_, digest := createHarmonyTestModel(t)

			streamFalse := false
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Model:    "harmony-test-streaming",
				Files:    map[string]string{"test.gguf": digest},
				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
				Stream:   &streamFalse,
			})

			if w.Code != 200 {
				t.Fatalf("failed to create model: %d", w.Code)
			}

			// Test chat endpoint with streaming
			streamTrue := true
			w = createRequest(t, s.ChatHandler, api.ChatRequest{
				Model:    "harmony-test-streaming",
				Messages: []api.Message{{Role: "user", Content: "Hello"}},
				Stream:   &streamTrue,
				Tools:    getTestTools(),
			})

			if w.Code != 200 {
				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
			}

			// Parse all chunks
			decoder := json.NewDecoder(w.Body)
			for decoder.More() {
				var chunk api.ChatResponse
				if err := decoder.Decode(&chunk); err != nil {
					t.Fatalf("failed to decode chunk: %v", err)
				}
				if chunk.Message.Content != "" || chunk.Message.Thinking != "" || len(chunk.Message.ToolCalls) > 0 {
					chunks = append(chunks, chunk)
				}
			}

			// Log received chunks for debugging
			if t.Failed() || len(chunks) == 0 {
				t.Logf("Received %d chunks:", len(chunks))
				for i, chunk := range chunks {
					t.Logf("  Chunk %d: content=%q thinking=%q", i, chunk.Message.Content, chunk.Message.Thinking)
				}
			}

			// Verify chunks match expected steps
			for i, step := range tc.steps {
				// Skip steps that don't expect any output
				if step.wantContent == "" && step.wantThinking == "" && len(step.wantToolCalls) == 0 {
					continue
				}

				if chunkIdx >= len(chunks) {
					t.Errorf("step %d: expected chunk not received (wanted content=%q thinking=%q)",
						i, step.wantContent, step.wantThinking)
					continue
				}

				chunk := chunks[chunkIdx]
				if chunk.Message.Content != step.wantContent || chunk.Message.Thinking != step.wantThinking {
					t.Errorf("step %d: chunk mismatch: got (content=%q, thinking=%q), want (content=%q, thinking=%q)",
						i, chunk.Message.Content, chunk.Message.Thinking, step.wantContent, step.wantThinking)
				}

				// Check tool calls if expected
				if len(step.wantToolCalls) > 0 {
					if len(chunk.Message.ToolCalls) != len(step.wantToolCalls) {
						t.Errorf("step %d: tool calls count mismatch: got %d, want %d",
							i, len(chunk.Message.ToolCalls), len(step.wantToolCalls))
					} else {
						for j, wantCall := range step.wantToolCalls {
							if j >= len(chunk.Message.ToolCalls) {
								break
							}
							gotCall := chunk.Message.ToolCalls[j]
							if gotCall.Function.Name != wantCall.Function.Name {
								t.Errorf("step %d, tool call %d: name mismatch: got %q, want %q",
									i, j, gotCall.Function.Name, wantCall.Function.Name)
							}
							// Compare arguments as JSON strings for simplicity
							gotArgs, _ := json.Marshal(gotCall.Function.Arguments)
							wantArgs, _ := json.Marshal(wantCall.Function.Arguments)
							if string(gotArgs) != string(wantArgs) {
								t.Errorf("step %d, tool call %d: arguments mismatch: got %s, want %s",
									i, j, string(gotArgs), string(wantArgs))
							}
						}
					}
				}
				chunkIdx++
			}

			// Check if we have extra chunks
			if chunkIdx < len(chunks) {
				t.Errorf("received %d extra chunks", len(chunks)-chunkIdx)
				for i := chunkIdx; i < len(chunks); i++ {
					t.Logf("  extra chunk %d: content=%q thinking=%q",
						i-chunkIdx, chunks[i].Message.Content, chunks[i].Message.Thinking)
				}
			}
		})
	}
}

// TestChatHarmonyParserStreamingSimple is a simpler test that just verifies basic streaming
func TestChatHarmonyParserStreamingSimple(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mockResponses := []llm.CompletionResponse{
		{Content: "<|message|>First ", Done: false},
		{Content: "chunk ", Done: false},
		{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
	}

	mock := mockRunner{
		CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
			t.Logf("Mock received prompt: %q", r.Prompt)
			t.Logf("Mock sending %d responses", len(mockResponses))
			for i, resp := range mockResponses {
				t.Logf("Sending response %d: %q", i, resp.Content)
				fn(resp)
			}
			return nil
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  100 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(t.Context())

	// Create model
	_, digest := createHarmonyTestModel(t)
	streamFalse := false
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    "gpt-oss",
		Files:    map[string]string{"test.gguf": digest},
		Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
		Stream:   &streamFalse,
	})

	if w.Code != 200 {
		t.Fatalf("failed to create model: %d", w.Code)
	}

	// Test streaming
	streamTrue := true
	w = createRequest(t, s.ChatHandler, api.ChatRequest{
		Model:    "gpt-oss",
		Messages: []api.Message{{Role: "user", Content: "Hello"}},
		Stream:   &streamTrue,
		Tools:    getTestTools(),
	})

	if w.Code != 200 {
		t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
	}

	// Parse chunks
	var chunks []api.ChatResponse
	decoder := json.NewDecoder(w.Body)
	for decoder.More() {
		var chunk api.ChatResponse
		if err := decoder.Decode(&chunk); err != nil {
			t.Fatalf("failed to decode chunk: %v", err)
		}
		chunks = append(chunks, chunk)
		t.Logf("Received chunk %d: content=%q thinking=%q done=%v",
			len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
	}

	// Verify we got chunks
	if len(chunks) == 0 {
		t.Fatal("expected streaming chunks, got none")
	}

	// Verify content
	var content strings.Builder
	for _, chunk := range chunks {
		content.WriteString(chunk.Message.Content)
	}

	expectedContent := "First chunk here"
	if content.String() != expectedContent {
		t.Errorf("content mismatch: got %q, want %q", content.String(), expectedContent)
	}

	// Verify we got multiple chunks (streaming)
	contentChunks := 0
	for _, chunk := range chunks {
		if chunk.Message.Content != "" {
			contentChunks++
		}
	}

	if contentChunks < 2 {
		t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
	}
}

func TestChatHarmonyParserStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type expectedChunk struct {
		afterResponse int    // Which mock response this chunk should appear after
		content       string // Expected content in this chunk
		thinking      string // Expected thinking in this chunk
	}

	testCases := []struct {
		name           string
		mockResponses  []llm.CompletionResponse
		expectedChunks []expectedChunk
		wantContent    string
		wantThinking   string
	}{
		{
			name: "simple message without thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
				{Content: "how can I help?", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 1, content: "Hello, "},
				{afterResponse: 2, content: "how can I help?"},
			},
			wantContent: "Hello, how can I help?",
		},
		{
			name: "message with analysis channel for thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|channel|>analysis<|message|>", Done: false},
				{Content: "Let me think ", Done: false},
				{Content: "about this problem...", Done: false},
				{Content: "<|end|>", Done: false},
				{Content: "<|start|>assistant<|message|>", Done: false},
				{Content: "The answer ", Done: false},
				{Content: "is 42", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 2, thinking: "Let me think "},
				{afterResponse: 3, thinking: "about this problem..."},
				{afterResponse: 6, content: "The answer "},
				{afterResponse: 7, content: "is 42"},
			},
			wantContent:  "The answer is 42",
			wantThinking: "Let me think about this problem...",
		},
		{
			name: "streaming with partial tags across boundaries",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|chan", Done: false},
				{Content: "nel|>analy", Done: false},
				{Content: "sis<|mess", Done: false},
				{Content: "age|>Think", Done: false},
				{Content: "ing deeply...<|end|>", Done: false},
				{Content: "<|start|>assi", Done: false},
				{Content: "stant<|message|>Result ", Done: false},
				{Content: "computed<|e", Done: false},
				{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 4, thinking: "Think"},
				{afterResponse: 5, thinking: "ing deeply..."},
				{afterResponse: 7, content: "Result "},
				{afterResponse: 8, content: "computed"},
			},
			wantContent:  "Result computed",
			wantThinking: "Thinking deeply...",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Channel to synchronize mock responses with chunk verification
			responsesSent := make(chan int, len(tc.mockResponses))

			mock := mockRunner{
				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
					// Send mock responses one at a time, notifying when each is sent
					for i, resp := range tc.mockResponses {
						fn(resp)
						responsesSent <- i + 1
					}
					close(responsesSent)
					return nil
				},
			}

			s := Server{
				sched: &Scheduler{
					pendingReqCh:  make(chan *LlmRequest, 1),
					finishedReqCh: make(chan *LlmRequest, 1),
					expiredCh:     make(chan *runnerRef, 1),
					unloadedCh:    make(chan any, 1),
					loaded:        make(map[string]*runnerRef),
					newServerFn:   newMockServer(&mock),
					getGpuFn:      discover.GetGPUInfo,
					getCpuFn:      discover.GetCPUInfo,
					reschedDelay:  250 * time.Millisecond,
					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
						req.successCh <- &runnerRef{
							llama: &mock,
						}
					},
				},
			}

			go s.sched.Run(t.Context())

			// Create a minimal model
			_, digest := createHarmonyTestModel(t)

			// Create model with passthrough template
			stream := false
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Model:    "harmony-test",
				Files:    map[string]string{"file.gguf": digest},
				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
				Stream:   &stream,
			})

			if w.Code != http.StatusOK {
				t.Fatalf("failed to create model: %d", w.Code)
			}

			// Test chat endpoint with streaming
			streamTrue := true
			w = createRequest(t, s.ChatHandler, api.ChatRequest{
				Model:    "harmony-test",
				Messages: []api.Message{{Role: "user", Content: "Hello"}},
				Stream:   &streamTrue,
				Tools:    getTestTools(),
			})

			if w.Code != http.StatusOK {
				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
			}

			// Parse streaming response
			var chunks []api.ChatResponse
			var content, thinking strings.Builder

			decoder := json.NewDecoder(w.Body)
			for decoder.More() {
				var chunk api.ChatResponse
				if err := decoder.Decode(&chunk); err != nil {
					t.Fatalf("failed to decode chunk: %v", err)
				}
				chunks = append(chunks, chunk)

				// Accumulate content and thinking from each chunk
				content.WriteString(chunk.Message.Content)
				thinking.WriteString(chunk.Message.Thinking)

				// Debug output
				t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
			}

			// Verify we got streaming chunks
			if len(chunks) == 0 {
				t.Fatal("expected streaming chunks, got none")
			}

			gotContent := content.String()
			gotThinking := thinking.String()

			if gotContent != tc.wantContent {
				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
			}
			if gotThinking != tc.wantThinking {
				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
			}

			// Verify last chunk has done=true
			lastChunk := chunks[len(chunks)-1]
			if !lastChunk.Done {
				t.Error("expected last chunk to have done=true")
			}
		})
	}
}
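
The harmony streaming tests above can be run in isolation with standard Go tooling, for example:

go test ./server -run TestChatHarmonyParserStreaming -v
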
@@ -13,6 +13,7 @@ import (
	"sync"
	"text/template"
	"text/template/parse"
	"time"

	"github.com/agnivade/levenshtein"

@@ -121,6 +122,11 @@ var funcs = template.FuncMap{
		b, _ := json.Marshal(v)
		return string(b)
	},
	"currentDate": func(args ...string) string {
		// Currently ignoring the format argument, but accepting it for future use
		// Default format is YYYY-MM-DD
		return time.Now().Format("2006-01-02")
	},
}

func Parse(s string) (*Template, error) {
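
A minimal usage sketch for the new helper, assuming the exported Parse, Execute, and Values shown in this file; the template text itself is hypothetical:

package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/template"
)

func main() {
	// currentDate is resolved from the funcs map added above
	tmpl, err := template.Parse("Current date: {{ currentDate }}\n{{ .Prompt }}")
	if err != nil {
		panic(err)
	}
	var sb strings.Builder
	if err := tmpl.Execute(&sb, template.Values{Prompt: "Hello"}); err != nil {
		panic(err)
	}
	fmt.Print(sb.String()) // "Current date: <today as YYYY-MM-DD>" followed by the prompt
}
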
@@ -160,12 +166,18 @@ func (t *Template) Vars() []string {
	return slices.Sorted(maps.Keys(set))
}

func (t *Template) Contains(s string) bool {
	return strings.Contains(t.raw, s)
}

type Values struct {
	Messages []api.Message
	api.Tools
	Prompt string
	Suffix string
	Think  bool
	// ThinkLevel contains the thinking level if Think is true and a string value was provided
	ThinkLevel string
	// whether or not the user explicitly set the thinking flag (vs. it being
	// implicitly false). Templates can't see whether `Think` is nil
	IsThinkSet bool
@@ -228,6 +240,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
			"Suffix":     v.Suffix,
			"Response":   "",
			"Think":      v.Think,
			"ThinkLevel": v.ThinkLevel,
			"IsThinkSet": v.IsThinkSet,
		})
	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
@@ -237,6 +250,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
			"Tools":      v.Tools,
			"Response":   "",
			"Think":      v.Think,
			"ThinkLevel": v.ThinkLevel,
			"IsThinkSet": v.IsThinkSet,
		})
	}
@@ -251,6 +265,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
			"Prompt":     prompt,
			"Response":   response,
			"Think":      v.Think,
			"ThinkLevel": v.ThinkLevel,
			"IsThinkSet": v.IsThinkSet,
		}); err != nil {
			return err
@@ -298,6 +313,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
			"Prompt":     prompt,
			"Response":   response,
			"Think":      v.Think,
			"ThinkLevel": v.ThinkLevel,
			"IsThinkSet": v.IsThinkSet,
		}); err != nil {
			return err
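
Since ThinkLevel is now exposed to templates alongside Think and IsThinkSet, a chat template can branch on it; a hypothetical fragment, not taken from this diff:

{{ if .IsThinkSet }}Reasoning: {{ if .ThinkLevel }}{{ .ThinkLevel }}{{ else if .Think }}on{{ else }}off{{ end }}
{{ end }}
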
@@ -26,6 +26,10 @@ type Parser struct {
	n int
}

func (p *Parser) GetBuffer() []byte {
	return p.buffer
}

// NewParser creates a new tool call parser from a model's chat
// template and a list of provided tools.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
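
For orientation, the parser is consumed the way ChatHandler does above; a condensed fragment, assuming the handler's m, req, and a streamed response r are in scope:

// non-harmony path: accumulate streamed content until a complete tool call parses
toolParser := tools.NewParser(m.Template.Template, req.Tools)
toolCalls, content := toolParser.Add(r.Content)
if len(toolCalls) > 0 {
	// forward completed calls to the client
}
_ = content // any non-tool text passes through unchanged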