Compare commits


1 Commit

Author SHA1 Message Date
ParthSareen  a5d638dfe7  extras  2025-03-12 16:12:29 -04:00
124 changed files with 2864 additions and 6056 deletions

View File

@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
)
endif()
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
CACHE STRING
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
)
check_language(HIP)
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()

View File

@@ -56,7 +56,7 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
}
],

View File

@@ -51,7 +51,7 @@ see if the change were accepted.
The title should look like:
<package>: <short description>
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known

View File

@@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm

View File

@@ -285,13 +285,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
@@ -325,7 +324,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -348,7 +346,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev, a VSCode extension for multi-file/whole-repo coding
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -394,10 +392,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
### Cloud
@@ -437,10 +431,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
### Apple Vision Pro
@@ -519,7 +510,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
### Mobile

View File

@@ -12,7 +12,6 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// StatusError is an error with an HTTP status code and message.
@@ -82,7 +81,7 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -107,7 +106,7 @@ type ChatRequest struct {
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
type Tools []Tool
@@ -163,65 +162,19 @@ func (t *ToolCallFunctionArguments) String() string {
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function ToolFunction `json:"function"`
}
// PropertyType can be either a string or an array of strings
type PropertyType []string
// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
*pt = []string{s}
return nil
}
// If that fails, try to unmarshal as an array of strings
var a []string
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*pt = a
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
if len(pt) == 1 {
// If there's only one type, marshal as a string
return json.Marshal(pt[0])
}
// Otherwise marshal as an array
return json.Marshal([]string(pt))
}
// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
if len(pt) == 0 {
return ""
}
if len(pt) == 1 {
return pt[0]
}
return fmt.Sprintf("%v", []string(pt))
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
@@ -307,7 +260,7 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
@@ -333,7 +286,7 @@ type EmbeddingRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// EmbeddingResponse is the response from [Client.Embeddings].
@@ -379,7 +332,7 @@ type ShowRequest struct {
Template string `json:"template"`
Verbose bool `json:"verbose"`
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -387,18 +340,16 @@ type ShowRequest struct {
// ShowResponse is the response returned from [Client.Show].
type ShowResponse struct {
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].
@@ -516,13 +467,6 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
@@ -551,7 +495,7 @@ func (m *Metrics) Summary() {
}
}
func (opts *Options) FromMap(m map[string]any) error {
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@@ -608,12 +552,12 @@ func (opts *Options) FromMap(m map[string]any) error {
}
field.SetString(val)
case reflect.Slice:
// JSON unmarshals to []any, not []string
val, ok := val.([]any)
// JSON unmarshals to []interface{}, not []string
val, ok := val.([]interface{})
if !ok {
return fmt.Errorf("option %q must be of type array", key)
}
// convert []any to []string
// convert []interface{} to []string
slice := make([]string, len(val))
for i, item := range val {
str, ok := item.(string)
@@ -720,7 +664,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]any, error) {
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@@ -734,7 +678,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
}
}
out := make(map[string]any)
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {

View File

@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var oMap map[string]any
var oMap map[string]interface{}
err := json.Unmarshal([]byte(test.req), &oMap)
require.NoError(t, err)
opts := DefaultOptions()
@@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
}
}
}
func TestToolFunction_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "valid enum with same types",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": ["a", "b", "c"]
}
}
}
}`,
wantErr: "",
},
{
name: "empty enum array",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": []
}
}
}
}`,
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tf ToolFunction
err := json.Unmarshal([]byte(tt.input), &tf)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected PropertyType
}{
{
name: "string type",
input: `"string"`,
expected: PropertyType{"string"},
},
{
name: "array of types",
input: `["string", "number"]`,
expected: PropertyType{"string", "number"},
},
{
name: "array with single type",
input: `["string"]`,
expected: PropertyType{"string"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var pt PropertyType
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pt) != len(test.expected) {
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
}
for i, v := range pt {
if v != test.expected[i] {
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
}
}
})
}
}
func TestPropertyType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input PropertyType
expected string
}{
{
name: "single type",
input: PropertyType{"string"},
expected: `"string"`,
},
{
name: "multiple types",
input: PropertyType{"string", "number"},
expected: `["string","number"]`,
},
{
name: "empty type",
input: PropertyType{},
expected: `[]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if string(data) != test.expected {
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
}
})
}
}

View File

@@ -1,178 +0,0 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -18,8 +18,6 @@ import (
"os/signal"
"path/filepath"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync/atomic"
@@ -268,7 +266,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
@@ -340,11 +338,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
@@ -575,9 +568,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
if boolErr != nil {
return errors.New("error retrieving flags")
}
@@ -615,7 +607,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
}
req := api.ShowRequest{Name: args[0], Verbose: verbose}
req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
@@ -638,10 +630,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil
}
return showInfo(resp, verbose, os.Stdout)
return showInfo(resp, os.Stdout)
}
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
func showInfo(resp *api.ShowResponse, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w)
@@ -675,15 +667,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
return
})
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
for _, capability := range resp.Capabilities {
rows = append(rows, []string{"", capability.String()})
}
return
})
}
if resp.ProjectorInfo != nil {
tableRender("Projector", func() (rows [][]string) {
arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -707,47 +690,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
})
}
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -852,7 +794,7 @@ type runOptions struct {
Format string
System string
Images []api.ImageData
Options map[string]any
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
}
@@ -1254,7 +1196,6 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",
@@ -1381,6 +1322,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],

View File

@@ -16,7 +16,6 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestShowInfo(t *testing.T) {
@@ -28,7 +27,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -58,7 +57,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -69,60 +68,6 @@ func TestShowInfo(t *testing.T) {
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -144,7 +89,7 @@ func TestShowInfo(t *testing.T) {
stop you
stop up
temperature 99`,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -181,7 +126,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -214,7 +159,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey!
Weigh anchor!
`,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -243,7 +188,7 @@ Weigh anchor!
QuantizationLevel: "FP16",
},
License: license,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -261,34 +206,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("capabilities", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
}, false, &b); err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture test \n" +
" parameters 7B \n" +
" quantization FP16 \n" +
"\n" +
" Capabilities\n" +
" vision \n" +
" tools \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {
@@ -790,132 +707,3 @@ func TestCreateHandler(t *testing.T) {
})
}
}
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
type MultilineState int
@@ -348,7 +347,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] {
case "info":
_ = showInfo(resp, false, os.Stderr)
_ = showInfo(resp, os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
@@ -460,16 +459,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{
Model: name,
From: cmp.Or(parentModel, opts.Model),
Name: name,
From: cmp.Or(opts.ParentModel, opts.Model),
}
if opts.System != "" {

View File

@@ -182,10 +182,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
var conv ModelConverter
switch p.Architectures[0] {
case "LlamaForCausalLM":
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
@@ -203,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return errors.New("unsupported architecture")
}
if err := json.Unmarshal(bts, conv); err != nil {

View File

@@ -87,7 +87,7 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow

View File

@@ -1,190 +0,0 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3Model struct {
ModelParameters
ImageTokenIndex uint32 `json:"image_token_index"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
TextModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
if p.ProjectorHiddenAct != "" {
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
}
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3Model) Replacements() []string {
return []string{
"language_model.model.norm", "output_norm",
"language_model.model.", "",
"language_model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -62,7 +62,10 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
Pattern string
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{
{"*.safetensors", parseSafetensors},
{"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch},

View File

@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_sentencepiece_model_proto_goTypes = []any{
var file_sentencepiece_model_proto_goTypes = []interface{}{
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
return
}
if !protoimpl.UnsafeEnabled {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TrainerSpec); i {
case 0:
return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NormalizerSpec); i {
case 0:
return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData); i {
case 0:
return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto); i {
case 0:
return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData_Sample); i {
case 0:
return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto_SentencePiece); i {
case 0:
return &v.state

View File

@@ -12,7 +12,7 @@ func IsNUMA() bool {
// numa support in llama.cpp is linux only
return false
}
ids := map[string]any{}
ids := map[string]interface{}{}
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
for _, packageId := range packageIds {
id, err := os.ReadFile(packageId)

View File

@@ -111,7 +111,6 @@ func GetCPUDetails() ([]CPU, error) {
if err != nil {
return nil, err
}
defer file.Close()
return linuxCPUDetails(file)
}
@@ -169,11 +168,13 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
for id, s := range socketByID {
s.CoreCount = len(coreBySocket[id])
s.ThreadCount = 0
for _, tc := range threadsByCoreBySocket[id] {
s.ThreadCount += tc
}
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
efficiencyCoreCount := 0
for _, threads := range threadsByCoreBySocket[id] {
s.ThreadCount += threads
if threads == 1 {
efficiencyCoreCount++
}

View File

@@ -558,10 +558,6 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,
@@ -1217,7 +1213,7 @@ Show information about a model including details, modelfile, template, parameter
```shell
curl http://localhost:11434/api/show -d '{
"model": "llava"
"model": "llama3.2"
}'
```
@@ -1260,11 +1256,7 @@ curl http://localhost:11434/api/show -d '{
"tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
},
"capabilities": [
"completion",
"vision"
],
}
}
```

View File

@@ -1,59 +0,0 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
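As a minimal sketch of that setup (the model name is only an example; use whichever model you plan to benchmark):

```bash
# start the server in one terminal
ollama serve

# pull the model to benchmark (example name)
ollama pull llama3.3
```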
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
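The optional flags can be combined with the same command for repeated runs; the values below are only illustrative:

```bash
# run the suite six times with a generous timeout for statistical comparison
go test -bench=. ./benchmark/... -m llama3.3 -count 6 -timeout 30m
```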
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -20,13 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
To change this when using `ollama run`, use `/set parameter`:
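For example, inside an interactive session (the model name and value are illustrative):

```shell
ollama run llama3.2
>>> /set parameter num_ctx 8192
```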
@@ -193,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored?

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager --follow --pager-end
journalctl -u ollama --no-pager
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -26,6 +26,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@@ -68,6 +69,10 @@ If you run into problems on Linux and want to install an older version, or you'd
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```
## Linux tmp noexec
If the location where Ollama stores its temporary executable files is mounted with the "noexec" flag, you can specify an alternate location by setting OLLAMA_TMPDIR to a directory writable by the user ollama runs as. For example, OLLAMA_TMPDIR=/usr/share/ollama/
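One way to set this for the systemd service is a drop-in override (a sketch; the directory is only an example):

```shell
sudo systemctl edit ollama.service
# add under the [Service] section:
#   Environment="OLLAMA_TMPDIR=/usr/share/ollama/"
sudo systemctl restart ollama
```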
## Linux docker
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
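A minimal sketch of that change, assuming `/etc/docker/daemon.json` does not already contain other settings (merge by hand if it does):

```shell
# write the cgroupfs driver option and restart Docker
echo '{ "exec-opts": ["native.cgroupdriver=cgroupfs"] }' | sudo tee /etc/docker/daemon.json
sudo systemctl restart docker
```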

View File

@@ -62,6 +62,7 @@ the explorer window by hitting `<Ctrl>+R` and type in:
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Uninstall

View File

@@ -5,7 +5,7 @@ import (
"time"
)
func assertEqual(t *testing.T, a any, b any) {
func assertEqual(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Assert failed, expected %v, got %v", b, a)
}

View File

@@ -1,13 +0,0 @@
package fs
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}

View File

@@ -134,10 +134,7 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"mistral3",
}, kv.Architecture())
return kv.Architecture() == "gemma3"
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
@@ -330,10 +327,6 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface {
Name() string
Decode(io.ReadSeeker) (model, error)
@@ -416,7 +409,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
@@ -429,10 +422,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = make([]uint64, f.KV().BlockCount())
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
switch f.KV().Architecture() {
case "llama":
@@ -466,14 +456,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama":
var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, uint32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32)
visionTokens *
tiles
}
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
kv = headsKV *
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
(2* // sizeof(float16)
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
context +
4* // sizeof(float32)
uint64(crossAttentionLayers.size)* // num cross attention layers
visionTokens*
tiles)
}
fullOffload = max(
@@ -509,20 +501,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16,
)
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {
const gemma3GlobalCacheCount = 6
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
for i := range kv {
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
// layers are the smaller local (sliding) layers.
if (i+1)%gemma3GlobalCacheCount != 0 {
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
}
}
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),
@@ -601,52 +579,39 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
}
return weights, graphSize
}

View File

@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
Prompt: "天空为什么是蓝色的?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
// Workaround deepseek context shifting bug
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},

View File

@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},

View File

@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
Model: "llama2",
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,

View File

@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
Model: "x/llama3.2-vision",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -66,35 +66,6 @@ func TestIntegrationMllama(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}
func TestIntegrationSplitBatch(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -20,7 +20,7 @@ var (
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -28,7 +28,7 @@ var (
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},

View File

@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
embedCtx := ctx
var genwg sync.WaitGroup
genwg.Add(1)
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
embedwg.Add(1)
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{

View File

@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "why is the color of dirt brown?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the origin of independence day?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
Prompt: "what is the composition of air?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},

View File

@@ -43,31 +43,20 @@ type Cache interface {
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// entry in positions and seqs.
StartForward(ctx ml.Context, opts input.Options) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
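For orientation, a minimal usage sketch of the updated interface, mirroring the calls the cache tests later in this diff make; backend, ctx, key and value are placeholders, and the error handling is abbreviated.

cache := NewCausalCache(nil)
defer cache.Close()

// one sequence, 16 cache entries per sequence, batches of up to 16 tokens
cache.Init(backend, ml.DTypeF16, 1, 16, 16)

err := cache.StartForward(ctx, input.Batch{
	Positions: []int32{0, 1, 2, 3},
	Sequences: []int{0, 0, 0, 0},
}, false)
if err != nil {
	// e.g. ErrKvCacheFull after defrag also failed
	return err
}
cache.SetLayer(0)
cache.Put(ctx, key, value)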

View File

@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
opts CausalOptions
@@ -97,7 +98,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -118,16 +119,9 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
cacheSize = maxSequences * capacity
} else {
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
}
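As a rough worked example of the new sizing with hypothetical parameters: with maxSequences = 4, capacity (context length) = 8192, maxBatch = 512 and a 1024-token sliding window, a full-attention cache would need 4*8192 = 32768 cells, while the sliding-window path only needs 4*1024 + 512 = 4608 cells before rounding up to CachePadding.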
@@ -146,60 +140,49 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
c.opts.Except = nil
if !reserve {
c.updateSlidingWindow()
var err error
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range opts.Positions {
seq := opts.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
var err error
c.curMask, err = c.buildMask(ctx)
return err
@@ -227,51 +210,7 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
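A short worked example of the eviction rule above (hypothetical positions): with windowSize = 4 and positions 0–7 already cached for a sequence, a new batch at positions 8 and 9 makes the lowest incoming position 8, so every cell with pos < 8 - 4 = 4 (positions 0–3) is dropped from that sequence while positions 4–7 stay resident.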
func roundDown(length, pad int) int {
@@ -326,7 +265,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i, key := range c.keys {
if key == nil {
continue
@@ -336,8 +275,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
@@ -345,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}
ctx.Forward(
@@ -382,8 +321,7 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
// copy for each of k and v).
layers := 0
for _, key := range c.keys {
if key == nil {
@@ -392,7 +330,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
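As a rough sanity check of the budget above (hypothetical values): with ctx.MaxGraphNodes() = 8192 and 32 layers that have cache tensors, the new formula gives maxMoves = (8192 - 2*32) / (6*32) = 8128 / 192 = 42 moves per defrag graph, the 2*layers term reserving nodes for the per-layer references to the original k and v tensors.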
@@ -541,14 +479,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}
@@ -559,7 +497,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
@@ -590,35 +528,6 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) CanResume(seq int, pos int32) bool {
if c.windowSize == math.MaxInt32 {
return true
}
seqRange, ok := c.cellRanges[seq]
if !ok {
return false
}
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
if last == -1 {
return false
}
lastWindowStart := max(0, last-c.windowSize)
posWindowStart := max(0, pos-c.windowSize)
return posWindowStart >= lastWindowStart
}
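To tie this to the test added below: with windowSize = 4 and positions 0–5 cached for sequence 0, last = 5 and lastWindowStart = max(0, 5-4) = 1; resuming at pos = 4 gives posWindowStart = max(0, 4-4) = 0, which is below 1, so CanResume returns false, while pos = 5 gives posWindowStart = 1 >= 1 and returns true.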
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil {
return ErrNotSupported
@@ -673,12 +582,6 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// TODO(jessegross): We should check to see if removing the middle of the sequence will
// cause the sliding window to encompass tokens that we no longer have. If so, then we
// should return an error, which will trigger the runner to evaluate the full history and
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// results in use after free, so we don't do it for now.
var offset int32
if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex
@@ -693,7 +596,8 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
} else {
if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
return errors.New("shifting cells shared by multiple sequences not supported")
// TODO(jessegross): Need to be careful about data shared between sequences
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
}
c.cells[i].pos += offset

View File

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF32, 16)
tests := []testCase{
{
name: "FirstBatch",
name: "SlidingWindow",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
@@ -71,16 +71,6 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
@@ -91,7 +81,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -126,7 +116,7 @@ func TestRemove(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -191,7 +181,7 @@ func TestDefrag(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -239,7 +229,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -280,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}
@@ -300,79 +290,14 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}
}
func TestCanResume(t *testing.T) {
backend := &testBackend{}
windowSize := int32(4)
cache := NewSWACache(windowSize, nil)
defer cache.Close()
type testBackend struct{}
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
if !cache.CanResume(0, 0) {
t.Errorf("CanResume(0, 0) = false, want true (within window)")
}
if !cache.CanResume(0, 1) {
t.Errorf("CanResume(0, 1) = false, want true (within window)")
}
if !cache.CanResume(0, 2) {
t.Errorf("CanResume(0, 2) = false, want true (within window)")
}
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
// shift the window by adding positions 4 and 5
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
if cache.CanResume(0, 0) {
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
}
if cache.CanResume(0, 1) {
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
}
if cache.CanResume(0, 2) {
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
}
if cache.CanResume(0, 3) {
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
}
if cache.CanResume(0, 4) {
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
}
if !cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
}
func (b *testBackend) Config() ml.Config {
panic("not implemented")
}
type testBackend struct {
ml.Backend
func (b *testBackend) Get(name string) ml.Tensor {
panic("not implemented")
}
func (b *testBackend) NewContext() ml.Context {
@@ -383,10 +308,12 @@ func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
type testContext struct {
ml.Context
func (b *testBackend) SystemInfo() string {
return "not implemented"
}
type testContext struct{}
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
@@ -425,14 +352,13 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
}
func (c *testContext) Input() ml.Context { return c }
func (c *testContext) Output() ml.Context { return c }
func (c *testContext) Layer(int) ml.Context { return c }
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10
}
@@ -440,8 +366,6 @@ func (c *testContext) MaxGraphNodes() int {
func (c *testContext) Close() {}
type testTensor struct {
ml.Tensor
dtype ml.DType
elementSize int
data []float32
@@ -469,20 +393,16 @@ func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Bytes() []byte {
panic("not implemented")
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
return out
}
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = -t.data[i]
}
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
@@ -493,6 +413,66 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out
}
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
@@ -515,6 +495,38 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
return view
}
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil

View File

@@ -27,11 +27,6 @@ type EncoderCache struct {
// anything will be stored)
curPos int32
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// ** cache metadata **
// was something stored in the cache?
@@ -54,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -63,10 +58,6 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, ca
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
@@ -88,14 +79,12 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
if len(opts.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
}
c.curReserve = reserve
return nil
}
@@ -112,10 +101,8 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
}
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
if !c.curReserve {
c.encoderPos = c.curPos
c.encoderCached = true
}
c.encoderPos = c.curPos
c.encoderCached = true
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
@@ -143,10 +130,6 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("encoder cache does not support multiple sequences")
}
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
return true
}
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
c.encoderCached = false

View File

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
for _, cache := range c.caches {
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
cache.Init(backend, dtype, capacity)
}
}
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch, reserve)
err := cache.StartForward(ctx, opts)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
for k := range opts.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
}
}
return err
@@ -87,16 +87,6 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
}
}
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
for _, cache := range c.caches {
if !cache.CanResume(seq, pos) {
return false
}
}
return true
}
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
for _, cache := range c.caches {

View File

@@ -37,7 +37,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -65,7 +64,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -806,24 +804,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@@ -1372,22 +1352,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
{
LLM_ARCH_MISTRAL3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
}
},
{
LLM_ARCH_UNKNOWN,
{

View File

@@ -41,7 +41,6 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -69,7 +68,6 @@ enum llm_arch {
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};

View File

@@ -878,9 +878,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1277,7 +1274,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
case LLM_ARCH_MISTRAL3: break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -2541,9 +2537,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3538,7 +3531,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
case LLM_ARCH_MISTRAL3: break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -4017,7 +4009,6 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_SOLAR:
case LLM_ARCH_MISTRAL3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@@ -4038,7 +4029,6 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:

View File

@@ -737,10 +737,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.") == std::string::npos;
quantize &= name.find("mm.") == std::string::npos;
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -166,10 +166,6 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c)
}
func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_cache_can_shift(c.c))
}
// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))

View File

@@ -1,173 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] add model quantizations
- gemma3
- mistral3
---
src/llama-arch.cpp | 36 ++++++++++++++++++++++++++++++++++++
src/llama-arch.h | 2 ++
src/llama-model.cpp | 10 ++++++++++
src/llama-quant.cpp | 4 ++++
4 files changed, 52 insertions(+)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..13a0a988 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -64,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+ { LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -804,6 +806,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
@@ -1352,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
+ {
+ LLM_ARCH_MISTRAL3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ }
+ },
{
LLM_ARCH_UNKNOWN,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..8476ae0a 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -68,6 +69,7 @@ enum llm_arch {
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
+ LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..db4f2685 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
+ case LLM_ARCH_MISTRAL3: break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
+ case LLM_ARCH_MISTRAL3: break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_SOLAR:
+ case LLM_ARCH_MISTRAL3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..ebcbafa1 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ // don't quantize vision stuff
+ quantize &= name.find("v.") == std::string::npos;
+ quantize &= name.find("mm.") == std::string::npos;
+
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -1,103 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Saman <saman.khatir@amd.com>
Date: Wed, 19 Mar 2025 14:02:26 -0700
Subject: [PATCH] add rdna4 support
---
ggml/src/ggml-cuda/common.cuh | 6 ++++--
ggml/src/ggml-cuda/mmq.cu | 2 +-
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
5 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index adf0d3ec..b24593fc 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -61,11 +61,13 @@
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 10f2ebb1..933d945c 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 0451c65f..66ce2bc9 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
__launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 4fb466ca..23ae7abc 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 81964611..a62544b5 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -150,6 +150,10 @@
#define CDNA
#endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define RDNA4
+#endif
+
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3

View File

@@ -1,75 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Wed, 2 Apr 2025 15:26:15 -0700
Subject: [PATCH] metal: add op_neg
---
ggml/src/ggml-metal/ggml-metal.m | 15 +++++++++++++++
ggml/src/ggml-metal/ggml-metal.metal | 7 +++++++
2 files changed, 22 insertions(+)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index e4c093f9..d8422f1b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
+ GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_UNARY_OP_NEG:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f38909d0..bb0ff668 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -945,6 +945,13 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
+kernel void kernel_neg(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = -src0[tpig];
+}
+
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,

View File

@@ -15,12 +15,12 @@ import (
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
// Split up the GPUs by type and try them
var estimatedVRAM uint64
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
estimate := EstimateGPULayers(gpus, f, projectors, opts)
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
// Graph size for a partial offload, applies to all GPUs
var graphPartialOffload uint64
@@ -137,19 +137,13 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
if len(kv) > 0 {
layerSize += kv[0]
}
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
// KV is proportional to the number of layers
layerSize += kv / f.KV().BlockCount()
if graphPartialOffload == 0 {
graphPartialOffload = f.KV().GQA() * kvTotal / 6
graphPartialOffload = f.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
@@ -223,9 +217,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
// Some models have inconsistent layer sizes
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv[i]
memoryWeights += blk.Size()
layerSize += kv / f.KV().BlockCount()
}
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the user's target NumGPU
@@ -321,7 +315,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
layersRequested: opts.NumGPU,
layersModel: int(f.KV().BlockCount()) + 1,
availableList: availableList,
kv: kvTotal,
kv: kv,
allocationsList: allocationsList,
memoryWeights: memoryWeights,
memoryLayerOutput: memoryLayerOutput,
@@ -380,9 +374,9 @@ func (m MemoryEstimate) LogValue() slog.Value {
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
"total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights),
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),

View File

@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
projectors := []string{}
opts := api.DefaultOptions()
t.Run("cpu", func(t *testing.T) {
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
assert.Equal(t, 0, estimate.Layers)
assert.Equal(t, uint64(0), estimate.Graph)
})
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
var layerSums uint64

View File

@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
gpus = discover.GetCPUInfo()
}
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
estimate := EstimateGPULayers(gpus, f, projectors, opts)
if len(gpus) > 1 || gpus[0].Library != "cpu" {
switch {
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
}
slog.Info("starting llama server", "cmd", s.cmd)
slog.Info("starting llama server", "cmd", s.cmd.String())
if envconfig.Debug() {
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
ServerStatusError
)
func (s ServerStatus) String() string {
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
@@ -485,9 +485,12 @@ func (s ServerStatus) String() string {
}
}
type ServerStatusResponse struct {
Status ServerStatus `json:"status"`
Progress float32 `json:"progress"`
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
SlotsProcessing int `json:"slots_processing"`
Error string `json:"error"`
Progress float32 `json:"progress"`
}
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -499,7 +502,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
}
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
}
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
@@ -524,19 +527,21 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
return ServerStatusError, fmt.Errorf("read health request: %w", err)
}
var ssr ServerStatusResponse
if err := json.Unmarshal(body, &ssr); err != nil {
var status ServerStatusResp
if err := json.Unmarshal(body, &status); err != nil {
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
}
switch ssr.Status {
case ServerStatusLoadingModel:
s.loadProgress = ssr.Progress
return ssr.Status, nil
case ServerStatusReady, ServerStatusNoSlotsAvailable:
return ssr.Status, nil
switch status.Status {
case "ok":
return ServerStatusReady, nil
case "no slot available":
return ServerStatusNoSlotsAvailable, nil
case "loading model":
s.loadProgress = status.Progress
return ServerStatusLoadingModel, nil
default:
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
return ServerStatusError, fmt.Errorf("server error: %+v", status)
}
}
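
For reference, a self-contained sketch of the string-keyed health check in the hunk above, assuming only the "ok", "no slot available" and "loading model" values that appear in the diff:

package main

import (
	"encoding/json"
	"fmt"
)

type serverStatusResp struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

func main() {
	body := []byte(`{"status":"loading model","progress":0.42}`)

	var resp serverStatusResp
	if err := json.Unmarshal(body, &resp); err != nil {
		panic(err)
	}

	switch resp.Status {
	case "ok":
		fmt.Println("server ready")
	case "no slot available":
		fmt.Println("no slots available")
	case "loading model":
		fmt.Printf("loading model: %.0f%%\n", resp.Progress*100)
	default:
		fmt.Println("server error:", resp.Status)
	}
}
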
@@ -611,7 +616,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
status, _ := s.getServerStatus(ctx)
if lastStatus != status && status != ServerStatusReady {
// Only log on status changes
slog.Info("waiting for server to become available", "status", status)
slog.Info("waiting for server to become available", "status", status.ToString())
}
switch status {
case ServerStatusReady:
@@ -625,7 +630,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration)
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
slog.Debug("model load completed, waiting for server to become available", "status", status)
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
stallTimer = time.Now().Add(stallDuration)
fullyLoaded = true
}
@@ -666,49 +671,63 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"`
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
type DoneReason int
const (
// DoneReasonStop indicates the completion stopped naturally
DoneReasonStop DoneReason = iota
// DoneReasonLength indicates the completion stopped due to length limits
DoneReasonLength
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
DoneReasonConnectionClosed
)
func (d DoneReason) String() string {
switch d {
case DoneReasonLength:
return "length"
case DoneReasonStop:
return "stop"
default:
return "" // closed
}
}
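
The typed DoneReason replaces ad-hoc strings on the wire; a tiny standalone mirror of the enum-to-string mapping for illustration (values assumed to match the hunk above):

package main

import "fmt"

type DoneReason int

const (
	DoneReasonStop DoneReason = iota
	DoneReasonLength
	DoneReasonConnectionClosed
)

func (d DoneReason) String() string {
	switch d {
	case DoneReasonLength:
		return "length"
	case DoneReasonStop:
		return "stop"
	default:
		return "" // connection closed
	}
}

func main() {
	for _, d := range []DoneReason{DoneReasonStop, DoneReasonLength, DoneReasonConnectionClosed} {
		fmt.Printf("%d -> %q\n", int(d), d.String())
	}
}
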
type CompletionResponse struct {
Content string `json:"content"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
Content string
DoneReason string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
"n_predict": req.Options.NumPredict,
"n_keep": req.Options.NumKeep,
"main_gpu": req.Options.MainGPU,
"temperature": req.Options.Temperature,
"top_k": req.Options.TopK,
"top_p": req.Options.TopP,
"min_p": req.Options.MinP,
"typical_p": req.Options.TypicalP,
"repeat_last_n": req.Options.RepeatLastN,
"repeat_penalty": req.Options.RepeatPenalty,
"presence_penalty": req.Options.PresencePenalty,
"frequency_penalty": req.Options.FrequencyPenalty,
"mirostat": req.Options.Mirostat,
"mirostat_tau": req.Options.MirostatTau,
"mirostat_eta": req.Options.MirostatEta,
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"cache_prompt": true,
}
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -716,7 +735,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// these as "not set".
break
case `"json"`:
req.Grammar = grammarJSON
request["grammar"] = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -727,15 +746,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
req.Grammar = string(g)
request["grammar"] = string(g)
}
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -756,7 +770,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %s", status)
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
// Handling JSON marshaling with special characters unescaped.
@@ -764,7 +778,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(req); err != nil {
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
@@ -809,12 +823,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line
}
var c CompletionResponse
var c completion
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
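
Streamed lines may or may not carry an SSE "data: " prefix, which is why the code falls back to the raw line when the prefix is absent. A minimal sketch of that handling with bytes.CutPrefix (Go 1.20+):

package main

import (
	"bytes"
	"fmt"
)

func main() {
	lines := [][]byte{
		[]byte(`data: {"content":"hello"}`),
		[]byte(`{"content":"world"}`), // no SSE prefix
	}

	for _, line := range lines {
		evt, ok := bytes.CutPrefix(line, []byte("data: "))
		if !ok {
			evt = line // use the raw line as-is
		}
		fmt.Println(string(evt))
	}
}
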
@@ -838,8 +853,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
})
}
if c.Done {
fn(c)
if c.Stop {
doneReason := "stop"
if c.StoppedLimit {
doneReason = "length"
}
fn(CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
return nil
}
}
@@ -887,7 +914,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %s", status)
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1032,3 +1059,12 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
}
return 0
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return dur
}
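
parseDurationMs leans on time.ParseDuration to handle fractional milliseconds. A small usage sketch:

package main

import (
	"fmt"
	"time"
)

// Same approach as parseDurationMs above: format the float as "<ms>ms" and let
// time.ParseDuration take care of the fractional part.
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		panic(err)
	}
	return dur
}

func main() {
	fmt.Println(parseDurationMs(1234.5)) // 1.2345s
}
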

View File

@@ -2,19 +2,28 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/fs"
)
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
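
The Config interface exposes typed key lookups with variadic defaults. A toy, self-contained implementation sketch; the keys and values are illustrative only and not taken from any real model file:

package main

import "fmt"

// Stand-in for a subset of the Config interface above.
type Config interface {
	Uint(key string, defaults ...uint32) uint32
	Float(key string, defaults ...float32) float32
}

// mapConfig is a toy implementation backed by a map.
type mapConfig map[string]any

func (m mapConfig) Uint(key string, defaults ...uint32) uint32 {
	if v, ok := m[key].(uint32); ok {
		return v
	}
	if len(defaults) > 0 {
		return defaults[0]
	}
	return 0
}

func (m mapConfig) Float(key string, defaults ...float32) float32 {
	if v, ok := m[key].(float32); ok {
		return v
	}
	if len(defaults) > 0 {
		return defaults[0]
	}
	return 0
}

func main() {
	var c Config = mapConfig{"embedding_length": uint32(2048)}
	fmt.Println(c.Uint("embedding_length", 4096))                  // 2048 (key present)
	fmt.Println(c.Float("attention.layer_norm_rms_epsilon", 1e-6)) // 1e-06 (default used)
}
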
type Backend interface {
Config() fs.Config
Config() Config
Get(name string) Tensor
NewContext() Context
NewContextSize(size int) Context
@@ -51,10 +60,6 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -71,9 +76,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -81,9 +86,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
backends[name] = f
}
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(ctx, f, params)
return backend(f, params)
}
return nil, fmt.Errorf("unsupported backend")
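
RegisterBackend and NewBackend follow a simple register-then-dispatch pattern keyed by backend name. A reduced, self-contained sketch of the same pattern (signatures simplified for illustration):

package main

import (
	"fmt"
	"os"
)

type Backend interface{ Name() string }

type ggmlBackend struct{}

func (ggmlBackend) Name() string { return "ggml" }

var backends = map[string]func(*os.File) (Backend, error){}

func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}
	backends[name] = f
}

func NewBackend(f *os.File) (Backend, error) {
	if ctor, ok := backends["ggml"]; ok {
		return ctor(f)
	}
	return nil, fmt.Errorf("unsupported backend")
}

func main() {
	RegisterBackend("ggml", func(*os.File) (Backend, error) { return ggmlBackend{}, nil })

	b, err := NewBackend(nil)  // the file is unused in this toy version
	fmt.Println(b.Name(), err) // ggml <nil>
}
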
@@ -97,20 +102,15 @@ type Context interface {
Forward(...Tensor) Context
Compute(...Tensor)
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available
// for future inference.
Reserve() error
MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
// Input returns a context appropriate for creating input tensors
Input() Context
// Output returns a context appropriate for creating output tensors
Output() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
}
@@ -125,7 +125,6 @@ type Tensor interface {
Bytes() []byte
Floats() []float32
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
@@ -140,10 +139,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
@@ -158,13 +154,9 @@ type Tensor interface {
Unpad(ctx Context, shape ...int) Tensor
Stack(ctx Context, dim int, s ...Tensor) Tensor
// Repeat repeats the tensor n times along dimension dim
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
}
// ScaledDotProductAttention implements a fused attention
@@ -229,7 +221,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)

View File

@@ -9,24 +9,20 @@ package ggml
import "C"
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
"unicode"
"unsafe"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup"
@@ -43,17 +39,16 @@ func devices() []*C.struct_ggml_backend_device {
}
type Backend struct {
meta *fsggml.GGML
sched *C.struct_ggml_backend_sched
schedBackends []*C.struct_ggml_backend
schedBufts []*C.struct_ggml_backend_buffer_type
meta *fs.GGML
sched *C.struct_ggml_backend_sched
tensors map[string]*C.struct_ggml_tensor
// input is the backend used for inputs
input *C.struct_ggml_backend_buffer_type
// output is the backend used for outputs
output *C.struct_ggml_backend_buffer_type
// layers is the backend used for repeating layers
layers map[int]*C.struct_ggml_backend_buffer_type
@@ -63,8 +58,8 @@ type Backend struct {
maxGraphNodes int
}
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fsggml.Decode(r, -1)
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
}
@@ -188,7 +183,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
maxTensors += blocks * 2
type tensor struct {
source *fsggml.Tensor
source *fs.Tensor
target string
}
@@ -286,10 +281,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
if b == nil {
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
}
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
bbs[c] = b
}
@@ -306,16 +297,12 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
}
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
for _, t := range meta.Tensors().Items() {
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
for _, target := range targets[t.Name] {
g.Go(func() error {
if target == "" {
target = t.Name
}
@@ -325,51 +312,23 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(r.Name())
if err != nil {
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
if err != nil {
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
if n != len(bts) {
return errors.New("short read")
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
return nil
})
}
}
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
if g.Wait() != nil {
return nil, err
}
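
One side of this hunk loads tensor data through a file descriptor per goroutine, reading each section sequentially in fixed-size chunks under an errgroup. A reduced sketch of that pattern; the offsets, sizes, and demo file are made up, and it assumes golang.org/x/sync is available:

package main

import (
	"fmt"
	"io"
	"os"

	"golang.org/x/sync/errgroup"
)

func main() {
	path := os.Args[0] // any readable file works for the demo
	sections := []struct{ off, size int64 }{{0, 4096}, {4096, 4096}}

	var g errgroup.Group
	for _, s := range sections {
		g.Go(func() error {
			// Separate FD per goroutine so each section is read sequentially.
			f, err := os.Open(path)
			if err != nil {
				return err
			}
			defer f.Close()

			sr := io.NewSectionReader(f, s.off, s.size)
			buf := make([]byte, 128*1024)
			var read int64
			for read < s.size {
				n, err := io.ReadFull(sr, buf[:min(int64(len(buf)), s.size-read)])
				if err != nil {
					return err
				}
				read += int64(n)
			}
			fmt.Printf("read %d bytes at offset %d\n", read, s.off)
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		fmt.Println("load failed:", err)
	}
}
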
@@ -394,6 +353,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
schedBackends = append(schedBackends, b)
schedBufts = append(schedBufts, bt)
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
if C.ggml_backend_is_cpu(b) {
// set number of threads for cpu backend
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
@@ -410,11 +371,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
true,
),
schedBackends: schedBackends,
schedBufts: schedBufts,
input: deviceBufferTypes[input.d],
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
m := make(map[int]*C.struct_ggml_backend_buffer_type)
for i, layer := range layers {
@@ -430,7 +390,7 @@ func init() {
ml.RegisterBackend("ggml", New)
}
func (b *Backend) Config() fs.Config {
func (b *Backend) Config() ml.Config {
return b.meta.KV()
}
@@ -495,6 +455,19 @@ func (c Context) Input() ml.Context {
return &c
}
func (c Context) Output() ml.Context {
if c.b.output != nil {
return &Context{
b: c.b,
ctx: c.ctx,
buft: c.b.output,
maxGraphNodes: c.maxGraphNodes,
}
}
return &c
}
func (c Context) Layer(i int) ml.Context {
if buft, ok := c.b.layers[i]; ok {
return &Context{
@@ -539,24 +512,6 @@ func (c Context) Compute(tensors ...ml.Tensor) {
}
}
func (c Context) Reserve() error {
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
C.ggml_backend_sched_reset(c.b.sched)
return errors.New("failed to reserve graph")
}
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
for i := range c.b.schedBackends {
size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
"size", format.HumanBytes2(uint64(size)))
}
C.ggml_backend_sched_reset(c.b.sched)
return nil
}
func (c Context) MaxGraphNodes() int {
return c.maxGraphNodes
}
@@ -574,9 +529,9 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
if c.buft == nil {
panic("set Input or Layer before creating tensors")
panic("set Input, Output, or Layer before creating tensors")
}
var cdtype uint32
@@ -597,7 +552,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -611,29 +566,16 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
if b == nil {
return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
}
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
return &Tensor{b: c.b, t: t}, nil
return &Tensor{b: c.b, t: t}
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
}
return t
return c.newTensor(dtype, shape)
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
}
t := c.newTensor(dtype, shape)
C.ggml_set_zero(t.(*Tensor).t)
return t
}
@@ -661,11 +603,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
return nil, err
}
t, err := c.newTensor(ml.DTypeF32, shape)
if err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape)
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
@@ -678,11 +616,7 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return nil, err
}
t, err := c.newTensor(ml.DTypeI32, shape)
if err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape)
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
@@ -766,13 +700,6 @@ func (t *Tensor) DType() ml.DType {
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -780,27 +707,6 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
}
shape := make([]C.int64_t, C.GGML_MAX_DIMS)
for i := range C.GGML_MAX_DIMS {
if i == dim {
shape[i] = C.int64_t(t.Dim(i) * n)
} else {
shape[i] = C.int64_t(t.Dim(i))
}
}
tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
return &Tensor{
b: t.b,
t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
}
}
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
if len(s) > 0 {
return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
@@ -937,20 +843,6 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sin(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cos(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1039,13 +931,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1114,10 +999,3 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_dup(ctx.(*Context).ctx, t.t),
}
}

View File

@@ -61,13 +61,11 @@
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
@@ -388,7 +386,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3) || defined(RDNA4)
#elif defined(RDNA3)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;

View File

@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)

View File

@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;

View File

@@ -150,10 +150,6 @@
#define CDNA
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define RDNA4
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3

View File

@@ -3083,13 +3083,6 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_neg(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,

View File

@@ -423,7 +423,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1040,7 +1039,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1204,7 +1202,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1876,18 +1873,6 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_NEG:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

View File

@@ -945,13 +945,6 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_neg(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,

View File

@@ -1,5 +1,4 @@
#include <string.h>
#include <inttypes.h>
#include "ollama-debug.h"
@@ -25,7 +24,7 @@ static void print_tensor(const void *tensor, void (*cb)(const void *, int),
fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems;
if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip;
@@ -68,7 +67,7 @@ static void print_tensor_i32(const void *tensor, int i) {
}
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]);

View File

@@ -1,7 +1,5 @@
package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
@@ -17,12 +15,6 @@ type Input struct {
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
@@ -35,24 +27,11 @@ type MultimodalIndex struct {
Multimodal any
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indices into Inputs for which output data should
// be returned.
Outputs []int32
Positions []int32
Sequences []int
Outputs []int32
}

View File

@@ -1,7 +1,6 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@@ -16,19 +15,16 @@ import (
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/model/input"
)
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Forward(ml.Context, input.Options) (ml.Tensor, error)
Backend() ml.Backend
Config() config
@@ -62,7 +58,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]input.Input) ([]input.Input, error)
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -84,10 +80,10 @@ func (m *Base) Config() config {
return m.config
}
var models = make(map[string]func(fs.Config) (Model, error))
var models = make(map[string]func(ml.Config) (Model, error))
// Register registers a model constructor for the given architecture
func Register(name string, f func(fs.Config) (Model, error)) {
func Register(name string, f func(ml.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
@@ -96,14 +92,14 @@ func Register(name string, f func(fs.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
}
@@ -132,14 +128,14 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
meta, _, err := fsggml.Decode(r, -1)
meta, _, err := fs.Decode(r, -1)
if err != nil {
return nil, err
}
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
@@ -282,30 +278,24 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
}
if len(batch.Positions) < 1 {
if len(opts.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
err := cache.StartForward(ctx, opts)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
t, err := m.Forward(ctx, opts)
if err != nil {
return nil, err
}

View File

@@ -7,8 +7,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs"
fsggml "github.com/ollama/ollama/fs/ggml"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
@@ -140,7 +139,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
}
func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fsggml.KV{})
tp, err := getTextProcessor(fs.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
@@ -149,10 +148,10 @@ func TestGetTextProcessor(t *testing.T) {
t.Error("expected nil tp")
}
models["dummy"] = func(fs.Config) (Model, error) {
models["dummy"] = func(ml.Config) (Model, error) {
return notTextProcessorModel{}, nil
}
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
@@ -164,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
panic("unimplemented")
}

View File

@@ -3,7 +3,6 @@ package gemma2
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -36,9 +35,10 @@ const (
gemma27BLayerCount = 46
)
func New(c fs.Config) (model.Model, error) {
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -168,18 +168,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
@@ -206,7 +211,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
}
func init() {

View File

@@ -2,11 +2,11 @@ package gemma3
import (
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -53,9 +53,10 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -83,10 +84,6 @@ func New(c fs.Config) (model.Model, error) {
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
@@ -111,23 +108,36 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var result []input.Input
fnvHash := fnv.New64a()
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
result = append(result, imageInputs...)
// add image embeddings
inputMultimodal := inp.Multimodal.(ml.Tensor)
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>"
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
// add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
@@ -139,18 +149,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}
func init() {

View File

@@ -3,7 +3,6 @@ package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -11,7 +10,7 @@ import (
"github.com/ollama/ollama/model/input"
)
type TextConfig struct {
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
@@ -28,7 +27,7 @@ type TextModel struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextConfig
*TextOptions
}
const (
@@ -41,11 +40,12 @@ const (
cacheTypeCausal
)
func newTextModel(c fs.Config) *TextModel {
func newTextModel(c ml.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -55,7 +55,7 @@ func newTextModel(c fs.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
@@ -84,7 +84,7 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
@@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextConfig.ropeLocalBase
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase
ropeBase = m.TextOptions.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
@@ -134,7 +134,7 @@ type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -148,7 +148,7 @@ type TextLayer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -171,21 +171,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor
var src, dst, length int
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
except = append(except, imageDst)
}
if embedding != nil {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
}
return except
}
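
setImageEmbeddings coalesces consecutive image-token placeholders that share an embedding into a single contiguous copy. A tensor-free sketch of the same run-coalescing idea using plain indices:

package main

import "fmt"

type run struct{ src, dst, length int }

// coalesce merges (src, dst) pairs that advance together into single runs.
func coalesce(pairs [][2]int) []run {
	var runs []run
	for _, p := range pairs {
		src, dst := p[0], p[1]
		if n := len(runs); n > 0 {
			last := &runs[n-1]
			if last.src+last.length == src && last.dst+last.length == dst {
				last.length++
				continue
			}
		}
		runs = append(runs, run{src: src, dst: dst, length: 1})
	}
	return runs
}

func main() {
	// Image rows 0..3 land at batch positions 7..10; row 9 lands at position 20.
	pairs := [][2]int{{0, 7}, {1, 8}, {2, 9}, {3, 10}, {9, 20}}
	fmt.Println(coalesce(pairs)) // [{0 7 4} {9 20 1}]
}
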
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
@@ -206,7 +239,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -3,7 +3,6 @@ package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
@@ -112,7 +111,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{

View File

@@ -3,7 +3,7 @@ package gemma3
import (
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
@@ -11,7 +11,7 @@ type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),

View File

@@ -5,7 +5,6 @@ import (
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -31,7 +30,7 @@ type Model struct {
*Options
}
func New(c fs.Config) (model.Model, error) {
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
@@ -140,18 +139,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)

View File

@@ -1,56 +0,0 @@
package mistral3
import (
"image"
_ "image/jpeg"
_ "image/png"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
}
}
// ProcessImage prepares an image for the vision model by:
// 1. Compositing transparent images
// 2. Resizing to fit model constraints while preserving aspect ratio
// 3. Normalizing pixel values
// Returns normalized image data and the final size in pixels
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) {
img = imageproc.Composite(img)
size := img.Bounds().Size()
ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge))
if ratio > 1.0 {
size = image.Point{
int(math.Floor(float64(size.X) / ratio)),
int(math.Floor(float64(size.Y) / ratio)),
}
}
patchesX := (size.X-1)/p.patchSize + 1
patchesY := (size.Y-1)/p.patchSize + 1
size = image.Point{
patchesX * p.patchSize,
patchesY * p.patchSize,
}
img = imageproc.Resize(img, size, imageproc.ResizeBilinear)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, size, nil
}
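
A worked example of the resize-and-pad-to-patches arithmetic above, assuming a hypothetical 3000x2000 input and the 1540/14 defaults shown in the diff:

package main

import (
	"fmt"
	"math"
)

func main() {
	const longestEdge, patchSize = 1540.0, 14
	w, h := 3000.0, 2000.0

	// Scale the longest edge down to the limit, preserving aspect ratio.
	ratio := math.Max(h/longestEdge, w/longestEdge)
	if ratio > 1.0 {
		w = math.Floor(w / ratio)
		h = math.Floor(h / ratio)
	}

	// Round each side up to a whole number of patches.
	patchesX := (int(w)-1)/patchSize + 1
	patchesY := (int(h)-1)/patchSize + 1
	fmt.Printf("resized to %dx%d -> %dx%d patches -> final %dx%d px\n",
		int(w), int(h), patchesX, patchesY, patchesX*patchSize, patchesY*patchSize)
}
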

View File

@@ -1,189 +0,0 @@
package mistral3
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {
return nil, err
}
m := &Model{
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
type PatchMerger struct {
MergingLayer *nn.Linear `gguf:"merging_layer"`
}
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor {
d := visionOutputs.Dim(0)
imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d)
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d)
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
return pm.MergingLayer.Forward(ctx, reshaped)
}
type MultiModalProjector struct {
Norm *nn.RMSNorm `gguf:"norm"`
Linear1 *nn.Linear `gguf:"linear_1"`
Linear2 *nn.Linear `gguf:"linear_2"`
PatchMerger *PatchMerger `gguf:"patch_merger"`
spatialMergeSize int
eps float32
patchSize int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps)
patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize}
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize)
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
visionOutputs = visionOutputs.GELU(ctx)
return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize}
}
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
return &MultiModalProjector{
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
eps: c.Float("text_config.rms_norm_eps", 1e-5),
patchSize: int(c.Uint("vision.patch_size", 14)),
}
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, size, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
for i := range rows {
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
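// For example, an image that decomposes into two rows of three patches is laid out as:
//
//	[IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]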
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
} else {
// [IMG_BREAK]
result = append(result, input.Input{Token: 12})
}
}
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("mistral3", New)
}

View File

@@ -1,177 +0,0 @@
package mistral3
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
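// Forward computes self-attention with rotary position embeddings applied to the queries
// and keys; numKVHeads may be smaller than numHeads (grouped-query attention), and the
// result is projected back to the hidden size.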
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
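// Forward applies the gated (SwiGLU-style) feed-forward network: down(silu(gate(x)) * up(x)).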
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// copy the image embeddings into the hidden state at each image token's position
for _, image := range batch.Multimodal {
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
temp := m.Backend().NewContext()
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
temp.Close()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func NewTextModel(c fs.Config) (*TextModel, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
textModel := &TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
return textModel, nil
}

View File

@@ -1,186 +0,0 @@
package mistral3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
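// rotateHalf splits the head dimension in half and returns (-x2, x1); together with
// applyRotaryPositionalEmbedding this implements rotary embeddings as t*cos + rotateHalf(t)*sin.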
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
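// Forward applies a pre-norm transformer block: RMSNorm and self-attention with a residual
// connection, followed by RMSNorm and the gated MLP with a second residual connection.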
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
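// positionalEmbedding builds the 2D RoPE angle table for a maxPatchesPerSide x maxPatchesPerSide
// grid: half of the rotary frequencies encode a patch's row and half encode its column. The rows
// matching positionIDs are then gathered so each patch receives its own angles.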
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
maxPatchesPerSide := m.imageSize / m.patchSize
frequencies := m.headDim / 2
frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
for i := range frequencies {
for j := range maxPatchesPerSide {
frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
if i%2 == 0 {
frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
} else {
frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
}
}
}
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
h = h.Repeat(ctx, 1, maxPatchesPerSide)
h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
w = w.Repeat(ctx, 2, maxPatchesPerSide)
inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
return inverseFrequencies.Rows(ctx, positionIDs)
}
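// Forward splits pixelValues into patchSize x patchSize patches with the convolutional patch
// embedding, normalizes them, applies 2D rotary position embeddings derived from each patch's
// row and column, and runs the encoder layers, returning one embedding per patch.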
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatchesH := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesW * numPatchesH
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
// Prepare position IDs for 2D RoPE: each patch is indexed by its row and column within the full maxPatchesPerSide grid
positions := make([]int32, numPatches)
for h := range numPatchesH {
for w := range numPatchesW {
idx := h*numPatchesW + w
positions[idx] = int32(h*m.imageSize/m.patchSize + w)
}
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return hiddenStates
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
},
}
}

View File

@@ -8,7 +8,6 @@ import (
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -33,7 +32,7 @@ const (
selfAttentionLayer
)
func New(c fs.Config) (model.Model, error) {
func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
@@ -64,10 +63,6 @@ func New(c fs.Config) (model.Model, error) {
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
@@ -107,17 +102,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -136,27 +131,29 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(batch.Multimodal) > 0 {
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
if len(opts.Multimodal) > 0 {
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
}
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -4,7 +4,6 @@ import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -221,7 +220,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask
return m.Output.Forward(ctx, hiddenState)
}
func newTextModel(c fs.Config) *TextModel {
func newTextModel(c ml.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer

View File

@@ -4,7 +4,6 @@ import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
@@ -186,7 +185,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
@@ -214,7 +213,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
return hiddenState.Concat(ctx, hiddenStates, 0)
}
func newVisionModel(c fs.Config) *VisionModel {
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},

View File

@@ -8,14 +8,14 @@ import (
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
)
type ImageProcessor struct {
imageSize, numChannels, maxNumTiles int
}
func newImageProcessor(c fs.Config) ImageProcessor {
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
numChannels: int(c.Uint("vision.num_channels")),

View File

@@ -4,6 +4,5 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -0,0 +1,68 @@
package pixtral
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"io"
"math"
"github.com/ollama/ollama/model/imageproc"
)
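// getNumImageTokens returns the number of patches along each axis: the image size divided by
// the patch size, rounded up.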
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
return image.Point{
(imageSize.X-1)/patchSize.X + 1,
(imageSize.Y-1)/patchSize.Y + 1,
}
}
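// getResizeOutputImageSize scales the image down (never up) so its longest edge is at most
// longestEdge, then rounds each dimension up to a whole number of patches. For example, a
// 1862x522 image with longestEdge 1024 and 16x16 patches comes out as 1024x288.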
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
b := img.Bounds()
le := float64(longestEdge)
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
newSize := img.Bounds().Max
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
}
}
tokens := getNumImageTokens(newSize, patchSize)
return image.Point{
tokens.X * patchSize.X,
tokens.Y * patchSize.Y,
}
}
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
if format == "png" {
img = imageproc.Composite(img)
}
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
// TODO: this should use bicubic resampling, but imageproc does not provide it yet
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
}
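// Preprocess decodes an image, resizes it so its longest edge is at most 1024 pixels (aligned
// to the 16x16 patch grid), and normalizes the pixels with the CLIP default mean and standard
// deviation, returning the flattened float32 data along with a (currently empty) options map.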
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
img, format, err := image.Decode(imageData)
if err != nil {
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
}
longestEdge := 1024
patchSize := image.Point{16, 16}
img = resizeImage(img, format, longestEdge, patchSize)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
opts := map[string]any{}
return data, opts, nil
}

View File

@@ -0,0 +1,219 @@
package pixtral
import (
"bytes"
"encoding/binary"
"image"
"image/png"
"math"
"os"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGetNumImageTokens(t *testing.T) {
type numImageTokensCase struct {
ImageSize image.Point
PatchSize image.Point
Expected image.Point
}
cases := []numImageTokensCase{
{
ImageSize: image.Point{1024, 764},
PatchSize: image.Point{16, 16},
Expected: image.Point{64, 48},
},
{
ImageSize: image.Point{800, 600},
PatchSize: image.Point{16, 16},
Expected: image.Point{50, 38},
},
{
ImageSize: image.Point{640, 480},
PatchSize: image.Point{16, 16},
Expected: image.Point{40, 30},
},
{
ImageSize: image.Point{320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{20, 13},
},
{
ImageSize: image.Point{1320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{83, 13},
},
{
ImageSize: image.Point{2000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{125, 13},
},
{
ImageSize: image.Point{10000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{625, 13},
},
{
ImageSize: image.Point{1131, 577},
PatchSize: image.Point{16, 16},
Expected: image.Point{71, 37},
},
{
ImageSize: image.Point{16, 16},
PatchSize: image.Point{16, 16},
Expected: image.Point{1, 1},
},
}
for _, c := range cases {
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestGetResizeOutputImageSize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Point
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 768},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 624},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{304, 208},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 288},
},
}
for _, c := range cases {
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestResize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Image
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
},
{
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
},
}
for _, c := range cases {
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
if actual.Bounds() != c.Expected.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
}
}
}
func TestPreprocess(t *testing.T) {
type preprocessCase struct {
TestImage image.Image
ExpectedLen int
}
cases := []preprocessCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
ExpectedLen: 16 * 16 * 3 * 1,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
ExpectedLen: 1024 * 1024 * 3 * 1,
},
}
for _, c := range cases {
var buf bytes.Buffer
err := png.Encode(&buf, c.TestImage)
if err != nil {
t.Fatal(err)
}
imgData, _, err := Preprocess(&buf)
if err != nil {
t.Fatalf("error processing: %q", err)
}
switch len(imgData) {
case 0:
t.Errorf("no image data returned")
case c.ExpectedLen:
// ok
default:
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
}
}
}
func TestPreprocessImages(t *testing.T) {
for _, testFile := range []string{"flight.png", "sportsball.png"} {
f, err := os.Open(testFile)
if err != nil {
t.Skipf("skipping test, no test image found at %s", testFile)
}
defer f.Close()
imgData, _, err := Preprocess(f)
if err != nil {
t.Fatalf("error processing: %q", err)
}
byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
for i, f := range imgData {
binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
}
outputPath := "processed_" + testFile + ".bin"
err = os.WriteFile(outputPath, byteData, 0o644)
if err != nil {
t.Fatalf("error writing processed image: %q", err)
}
}
}

View File

@@ -263,10 +263,6 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

View File

@@ -1,23 +1,29 @@
package model
import (
"container/heap"
"fmt"
"iter"
"log/slog"
"strconv"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
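// replaceWhitespaceBySeperator maps ASCII spaces to the SentencePiece whitespace marker
// "▁" (U+2581) so that split fragments match the vocabulary's surface forms.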
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -38,6 +44,7 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
@@ -46,9 +53,20 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
@@ -73,6 +91,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
@@ -81,96 +100,105 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
continue
}
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
slog.Debug("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
}
}
return nil
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
if id >= 0 {
ids = append(ids, id)
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
continue
}
ids = append(ids, result...)
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
@@ -201,30 +229,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
type candidate struct {
a, b int
score float32
size int
}
type queue []*candidate
func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
@@ -232,26 +236,11 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
// For tokenizers that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {
return "", err
}
} else {
if _, err := sb.WriteString(data); err != nil {
return "", err
}
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@@ -25,6 +25,8 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
@@ -45,7 +47,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
return NewSentencePieceModel(&v)
return NewSentencePieceModel(preTokenizer, &v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -114,59 +116,3 @@ func TestSentencePieceEncode(t *testing.T) {
}
})
}
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
"<0xEA>",
"<0x41>",
"<0xC3>",
"<0xA3>",
},
Types: []uint32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,
},
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePieceModel(vocab)
tests := []struct {
name string
ids []int32
expected string
}{
{
name: "single byte token",
ids: []int32{1},
expected: "\xea",
},
{
name: "ASCII byte token",
ids: []int32{2},
expected: "A",
},
{
name: "multiple byte tokens forming UTF-8 character",
ids: []int32{3, 4},
expected: "ã",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := spm.Decode(tt.ids)
if err != nil {
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
}
if result != tt.expected {
t.Errorf("got %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -23,10 +23,10 @@ import (
var finishReasonToolCalls = "tool_calls"
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param any `json:"param"`
Code *string `json:"code"`
Message string `json:"message"`
Type string `json:"type"`
Param interface{} `json:"param"`
Code *string `json:"code"`
}
type ErrorResponse struct {
@@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
options := make(map[string]any)
options := make(map[string]interface{})
switch stop := r.Stop.(type) {
case string:

View File

@@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
},
@@ -281,31 +281,27 @@ func TestChatMiddleware(t *testing.T) {
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: api.PropertyType{"string"},
Type: "string",
Description: "The city and state",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},

View File

@@ -11,13 +11,10 @@ import (
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
@@ -147,25 +144,12 @@ func fileDigestMap(path string) (map[string]string, error) {
files = []string{path}
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
digest, err := digestForFile(f)
if err != nil {
return nil, err
}
fl[f] = digest
}
return fl, nil
@@ -227,10 +211,16 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin

View File

@@ -213,16 +213,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
return discard
}
type ErrReprocessInputs struct {
Inputs []input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
@@ -230,8 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
inputLen := len(slot.Inputs)
discard := c.ShiftDiscard(inputLen, numKeep)
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
if discard <= 0 {
return nil
@@ -240,42 +231,16 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
var shiftFailed bool
if c.lc.KvCacheCanShift() {
// For models that support shifting, attempt to shift the KV cache
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
shiftFailed = true
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
} else {
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
}
} else {
// For models that don't support shifting
shiftFailed = true
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
// TODO (jessegross): KV cache removal can fail for certain types of models
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
if shiftFailed {
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Clear the entire KV cache
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
// Reset the slot inputs since we've cleared the cache
slot.Inputs = []input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
// Standard shift succeeded - update input array
for i := numKeep + discard; i < inputLen; i++ {
for i := numKeep + discard; i < len(slot.Inputs); i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:inputLen-discard]
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
return nil
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runner/common"
)
@@ -83,7 +82,7 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
doneReason llm.DoneReason
doneReason string
// Metrics
startProcessingTime time.Time
@@ -100,7 +99,7 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
@@ -164,7 +163,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
@@ -230,7 +229,7 @@ type Server struct {
image *ImageContext
// status for external health reporting - loading, ready to serve, etc.
status llm.ServerStatus
status ServerStatus
// current progress on loading the model
progress float32
@@ -301,7 +300,7 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex]
flushPending(seq)
@@ -380,7 +379,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
s.removeSequence(seqIdx, "limit")
continue
}
@@ -389,15 +388,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Continue processing as normal
continue
} else {
return err
}
return err
}
} else {
break
@@ -482,7 +473,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
seq.embedding <- embed
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "")
continue
}
@@ -499,7 +490,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "stop")
continue
}
@@ -530,7 +521,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "stop")
continue
}
@@ -543,25 +534,82 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
s.removeSequence(i, "connection")
}
}
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req llm.CompletionRequest
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -572,28 +620,26 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Extract options from the CompletionRequest
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TypicalP: req.Options.TypicalP,
Temp: req.Options.Temperature,
RepeatLastN: req.Options.RepeatLastN,
PenaltyRepeat: req.Options.RepeatPenalty,
PenaltyFreq: req.Options.FrequencyPenalty,
PenaltyPresent: req.Options.PresencePenalty,
Mirostat: req.Options.Mirostat,
MirostatTau: req.Options.MirostatTau,
MirostatEta: req.Options.MirostatEta,
Seed: uint32(req.Options.Seed),
Grammar: req.Grammar,
}
var samplingParams llama.SamplingParams
samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
@@ -607,7 +653,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
@@ -616,10 +662,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -635,7 +680,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -647,7 +691,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -657,13 +701,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numDecoded,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -674,8 +721,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req llm.EmbeddingRequest
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
@@ -696,7 +752,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
@@ -705,10 +761,9 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -721,24 +776,47 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -801,7 +879,7 @@ func (s *Server) loadModel(
panic(err)
}
s.status = llm.ServerStatusReady
s.status = ServerStatusReady
s.ready.Done()
}
@@ -859,7 +937,7 @@ func Execute(args []string) error {
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: llm.ServerStatusLoadingModel,
status: ServerStatusLoadingModel,
}
var tensorSplitFloats []float32

View File

@@ -31,10 +31,8 @@ type InputCache struct {
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
@@ -46,11 +44,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
}
return &InputCache{
numCtx: numCtx,
numCtx: kvSize / int32(numSlots),
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
@@ -91,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -109,6 +107,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()
@@ -118,10 +120,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
}
if c.cache != nil {
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
numPast = 0
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
if err != nil {
// Some models don't support partial erasure
@@ -229,8 +227,6 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
return count
}
// TODO(jessegross): If we need to reprocess the inputs we should ensure that
// we don't split up a SameBatch
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
@@ -245,14 +241,6 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
return discard
}
type ErrReprocessInputs struct {
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -272,23 +260,11 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if c.cache != nil {
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
if err != nil {
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, -1)
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
}
}
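Concretely, for the ErrReprocessInputs path above: with inputs 1..10, numKeep 2 and discard 4, newInputs is slot.Inputs[:2] followed by slot.Inputs[6:], so tokens 1, 2, 7, 8, 9 and 10 survive, the slot's cache is cleared, and those six inputs are handed back to the caller for re-prompting.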

View File

@@ -1,13 +1,10 @@
package ollamarunner
import (
"errors"
"fmt"
"image"
"testing"
"time"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@@ -300,220 +297,3 @@ func TestShiftDiscard(t *testing.T) {
})
}
}
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}
// Mock implementation of the Cache interface
type mockCache struct {
shouldFail bool
}
// Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail {
return fmt.Errorf("mock cache removal error")
}
return nil
}
// Stub implementations for other interface methods
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }
func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
wantInputsLen int
}{
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
wantInputsLen: 6, // After discarding 4 tokens
},
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
wantInputsLen: 0, // Original inputs should be cleared
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockCache{shouldFail: tt.cacheErr}
c := InputCache{
numCtx: tt.numCtx,
cache: mock,
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)
err := c.ShiftCacheSlot(slot, tt.numKeep)
if tt.wantErr != nil {
if err == nil {
t.Errorf("Expected error but got nil")
return
}
if !errors.As(err, &tt.wantErr) {
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
}
} else if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(slot.Inputs) != tt.wantInputsLen {
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
}
})
}
}
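The wantErr comparison above leans on errors.As with a pointer error type; a minimal standalone sketch of the same pattern, with purely illustrative names:

package main

import (
	"errors"
	"fmt"
)

// ErrNeedsRetry mirrors the shape of ErrReprocessInputs: a pointer error type
// that carries the data the caller needs in order to recover.
type ErrNeedsRetry struct{ Items []int }

func (e *ErrNeedsRetry) Error() string { return fmt.Sprintf("retry %d items", len(e.Items)) }

func main() {
	var err error = &ErrNeedsRetry{Items: []int{1, 2, 3}}

	var target *ErrNeedsRetry
	if errors.As(err, &target) {
		fmt.Println("recovered items:", target.Items) // prints: recovered items: [1 2 3]
	}
}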

View File

@@ -24,7 +24,6 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -34,14 +33,10 @@ import (
_ "github.com/ollama/ollama/model/models"
)
type contextList struct {
list []ml.Context
}
type Sequence struct {
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
// ctx for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctxs *contextList
ctx ml.Context
// batch index
iBatch int
@@ -82,7 +77,7 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool
doneReason llm.DoneReason
doneReason string
// Metrics
startProcessingTime time.Time
@@ -99,12 +94,13 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, ctxs, err := s.inputs(prompt, images)
inputs, err := s.inputs(ctx, prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -120,36 +116,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if int32(len(inputs)) > s.cache.numCtx {
discard := int32(len(inputs)) - s.cache.numCtx
promptStart := params.numKeep + discard
// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
sameBatch := 0
for i, inp := range inputs {
if sameBatch > 0 {
sameBatch--
if promptStart == int32(i) {
promptStart++
}
} else if promptStart == int32(i) {
break
}
if inp.SameBatch != 0 {
if int32(i) < params.numKeep {
return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
}
sameBatch = inp.SameBatch
}
}
if promptStart >= int32(len(inputs)) {
return nil, errors.New("entire prompt removed by truncation")
}
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[promptStart:]...)
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
inputs = newInputs
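As a worked example of this truncation (with no SameBatch groups in play): 10 inputs, numCtx 8 and numKeep 2 give discard 2 and promptStart 4, so the prompt becomes inputs[:2] plus inputs[4:], eight tokens that now fit the context window.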
@@ -158,7 +126,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctxs: ctxs,
ctx: ctx,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -177,7 +145,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
var inputs []input.Input
var parts []string
var matches [][]string
@@ -192,19 +160,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
parts = []string{prompt}
}
var contexts contextList
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
for _, ctx := range ctxs {
ctx.Close()
}
}, contexts.list)
postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, err
return nil, err
}
for _, t := range tokens {
@@ -224,14 +185,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
}
if imageIndex < 0 {
return nil, nil, fmt.Errorf("invalid image index: %d", n)
return nil, fmt.Errorf("invalid image index: %d", n)
}
ctx := s.model.Backend().NewContext()
contexts.list = append(contexts.list, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, nil, err
return nil, err
}
s.multimodalHash.Reset()
@@ -245,13 +204,13 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(inputs)
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, nil, err
return nil, err
}
}
return inputs, &contexts, nil
return inputs, nil
}
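For reference, a small standalone sketch of the [img-<n>] splitting that inputs performs; the tag pattern is written out here as an assumption, since the actual regular expression sits above the hunks shown:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Assumed tag pattern; the real pattern is defined elsewhere in runner.go.
	re := regexp.MustCompile(`\[img-(\d+)\]`)

	prompt := "describe [img-0] please"
	parts := re.Split(prompt, -1)                   // ["describe ", " please"] - text to tokenize
	matches := re.FindAllStringSubmatch(prompt, -1) // [["[img-0]", "0"]] - images to decode

	fmt.Println(parts, matches)
}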
type Server struct {
@@ -263,7 +222,7 @@ type Server struct {
model model.Model
// status for external health reporting - loading, ready to serve, etc.
status llm.ServerStatus
status ServerStatus
// current progress on loading the model
progress float32
@@ -292,9 +251,6 @@ type Server struct {
// KV cache
cache *InputCache
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
@@ -341,7 +297,7 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex]
flushPending(seq)
@@ -349,6 +305,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -376,22 +333,16 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var batchInputs []int32
var batch input.Batch
resumeSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
var options input.Options
for i, seq := range s.seqs {
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
s.removeSequence(i, "limit")
continue
}
@@ -400,61 +351,33 @@ func (s *Server) processBatch() error {
seq.cache.Inputs = []input.Input{}
}
batchSize := s.batchSize
for i, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have existing inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
}
// Stop if the required batch would put us over the total batch size (including tokens
// added by other sequences). If we haven't been able to add anything yet then pick up
// here again for the next batch to avoid starvation, though we can opportunistically
// check if other sequences can still squeeze something in.
if len(batchInputs)+minBatch > batchSize {
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
resumeSeq = seqIdx
}
if j >= s.batchSize {
break
}
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
continue
} else {
return err
}
}
}
batchInputs = append(batchInputs, inp.Token)
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
}
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
seq.iBatch = len(options.Outputs)
if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
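To make the minBatch arithmetic concrete: with a batch size of 512, an input whose SameBatch is 728 (an illustrative figure for a multimodal embedding that must stay contiguous) yields minBatch 729, so batchSize is raised for this sequence; and because len(batchInputs)+minBatch then exceeds 512 whenever the batch already holds anything, such an input can only start on an otherwise empty batch.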
@@ -462,20 +385,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
s.nextSeq = seqIdx + 1
}
if len(batchInputs) == 0 {
if len(options.Inputs) == 0 {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
modelOutput, err := model.Forward(ctx, s.model, options)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -510,12 +427,12 @@ func (s *Server) processBatch() error {
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "")
continue
}
// sample a token
vocabSize := len(logits) / len(batch.Outputs)
vocabSize := len(logits) / len(options.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
@@ -528,7 +445,7 @@ func (s *Server) processBatch() error {
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "stop")
continue
}
@@ -564,7 +481,7 @@ func (s *Server) processBatch() error {
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop)
s.removeSequence(i, "stop")
continue
}
@@ -577,25 +494,82 @@ func (s *Server) processBatch() error {
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
s.removeSequence(i, "connection")
}
}
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
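Reading the json tags above together, a request to the completion handler that follows looks roughly like {"prompt":"why is the sky blue?","n_predict":64,"temperature":0.7,"cache_prompt":true,"stop":["\n"]} (values illustrative); streamed chunks carry a content field, and the final object sets stop to true with prompt and predicted counts plus millisecond durations under timings.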
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req llm.CompletionRequest
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -617,18 +591,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
grammar,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: int32(req.Options.NumKeep),
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
sampler: sampler,
embedding: false,
})
@@ -642,7 +616,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
@@ -651,10 +625,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -668,7 +641,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -680,7 +652,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -690,13 +662,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numPredicted,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -707,10 +682,102 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
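In the same spirit, the embedding endpoint accepts a body like {"content":"hello world","cache_prompt":true} and replies with {"embedding":[0.12,-0.04, ...]} once the sequence finishes (numbers illustrative).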
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
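Given these status strings, a health probe sees something like {"status":"loading model","progress":0.42} while weights are loading and {"status":"ok","progress":1} once ready (progress values illustrative).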
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -728,53 +795,7 @@ func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) reserveWorstCaseGraph() error {
ctx := s.model.Backend().NewContext()
defer ctx.Close()
var batch input.Batch
inputs := make([]int32, s.batchSize)
batch.Positions = make([]int32, len(inputs))
batch.Sequences = make([]int, len(inputs))
for i := range inputs {
batch.Positions[i] = int32(i)
}
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return err
}
cache := s.model.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, true)
if err != nil {
return err
}
}
t, err := s.model.Forward(ctx, batch)
if err != nil {
return err
}
err = ctx.Forward(t).Reserve()
if err != nil {
return err
}
return nil
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -784,7 +805,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(ctx, mpath, params)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
@@ -796,7 +817,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented")
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
if err != nil {
panic(err)
}
@@ -810,12 +831,7 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
err = s.reserveWorstCaseGraph()
if err != nil {
panic(err)
}
s.status = llm.ServerStatusReady
s.status = ServerStatusReady
s.ready.Done()
}
@@ -867,7 +883,7 @@ func Execute(args []string) error {
server := &Server{
batchSize: *batchSize,
status: llm.ServerStatusLoadingModel,
status: ServerStatusLoadingModel,
}
// TODO(jessegross): Parameters that need to be implemented:
@@ -885,9 +901,6 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
@@ -896,13 +909,13 @@ func Execute(args []string) error {
}
server.ready.Add(1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
@@ -914,13 +927,9 @@ func Execute(args []string) error {
defer listener.Close()
mux := http.NewServeMux()
// TODO: support embeddings
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
httpServer := http.Server{
Handler: mux,

View File

@@ -1,11 +1,10 @@
package sample
import (
"errors"
"math"
"math/rand/v2"
"slices"
"math/rand"
"sync"
"time"
"github.com/ollama/ollama/llama"
)
@@ -26,10 +25,6 @@ type Sampler struct {
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -91,50 +86,53 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
// scale and normalize the tokens in place
temperature(tokens, s.temperature)
softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
var r float32
if s.rng != nil {
r = s.rng.Float32()
} else {
r = rand.Float32()
}
// token logit values are updated to probabilities
temperature(tokens, s.temperature)
softmax(tokens)
return tokens[dist(tokens, s.rng.Int63())], nil
// Calculate cumulative sum of probabilities
var sum float32
for i := range tokens {
sum += tokens[i].value
tokens[i].value = sum
}
r *= tokens[len(tokens)-1].value
// // TODO: this should fall back to greedy sampling
// // or topP, topK values etc should be such that
// // there are always tokens to sample from
// if len(tokens) == 0 {
// return token{}, errors.New("no tokens to sample from")
// }
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
if token.value < target {
return -1
}
return 1
})
// var r float32
// if s.rng != nil {
// r = s.rng.Float32()
// } else {
// r = rand.Float32()
// }
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil
// // Calculate cumulative sum of probabilities
// var sum float32
// for i := range tokens {
// sum += tokens[i].value
// tokens[i].value = sum
// }
// r *= tokens[len(tokens)-1].value
// idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
// if token.value < target {
// return -1
// }
// return 1
// })
// return tokens[idx], nil
}
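A self-contained sketch of the cumulative-sum plus binary-search selection shown above, using standalone names rather than the package's own types:

package main

import (
	"fmt"
	"math/rand"
	"slices"
)

func main() {
	// Normalized probabilities in descending order, as softmax produces after topK.
	probs := []float32{0.5, 0.3, 0.15, 0.05}

	// Build the cumulative distribution.
	cdf := make([]float32, len(probs))
	var sum float32
	for i, p := range probs {
		sum += p
		cdf[i] = sum
	}

	// Scale the random draw by the total mass in case it is slightly below 1.
	r := rand.Float32() * sum

	// First index whose cumulative value exceeds r; the comparator never
	// returns 0, so BinarySearchFunc yields the insertion point.
	idx, _ := slices.BinarySearchFunc(cdf, r, func(c, target float32) int {
		if c < target {
			return -1
		}
		return 1
	})
	fmt.Println("sampled index:", idx)
}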
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
rng = rand.New(rand.NewSource(int64(seed)))
} else {
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
}
if temperature < 0.0 {
temperature = 0.0

View File

@@ -1,7 +1,6 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
)
@@ -30,29 +29,6 @@ func TestWeighted(t *testing.T) {
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
}
func BenchmarkSample(b *testing.B) {

1 sample/testdata/logits.bin vendored Normal file
View File

File diff suppressed because one or more lines are too long

View File

@@ -3,6 +3,7 @@ package sample
import (
"container/heap"
"math"
"math/rand"
"slices"
)
@@ -25,38 +26,6 @@ func (h *tokenHeap) Pop() any {
return x
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
}
// softmax applies normalization to the logits
func softmax(ts []token) {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Compute exp(x - max)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}
}
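As a quick numeric check of the softmax above: logits {1, 2, 3} exponentiate to roughly 2.718, 7.389 and 20.086 and normalize to about 0.090, 0.245 and 0.665; subtracting the max logit first leaves these probabilities unchanged and only keeps the exponentials from overflowing.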
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 {
@@ -96,7 +65,6 @@ func topK(ts []token, k int) []token {
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
@@ -107,24 +75,93 @@ func topP(ts []token, p float32) []token {
for i, t := range ts {
sum += t.value
if sum > float32(p) {
return ts[:i+1]
ts = ts[:i+1]
return ts
}
}
return ts
}
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
// minP limits tokens to those with cumulative probability p
func minP(ts []token, p float32) []token {
maxProb := ts[0].value
if p == 1.0 {
return ts
}
threshold := maxProb * p
for i, t := range ts {
if t.value < threshold {
return ts[:i]
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
}
func temperature(ts []token, temp float32) {
for i := range ts {
ts[i].value /= temp
}
}
func softmax(ts []token) {
if len(ts) == 0 {
return
}
// Find max logit for numerical stability
maxLogit := ts[0].value
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Compute exp(logit - maxLogit) and sum them
var sumExp float32
for i, t := range ts {
expVal := float32(math.Exp(float64(t.value - maxLogit)))
ts[i].value = expVal
sumExp += expVal
}
// Normalize probabilities
for i := range ts {
ts[i].value /= sumExp
}
}
// dist selects a token based on probabilities and seed
func dist(ts []token, seed int64) int {
rng := rand.New(rand.NewSource(seed))
cdf := make([]float32, len(ts))
var cumSum float32
for i, t := range ts {
cumSum += t.value
cdf[i] = cumSum
}
r := rng.Float32() * cumSum
// Select token based on CDF
for i, probSum := range cdf {
if r < probSum {
return i
}
}
return len(ts) - 1
}
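Tying the helpers in this file together, an in-package sketch of one plausible ordering; the exact sequence inside sample() is only partly visible in the hunks above, so treat the ordering here as illustrative rather than definitive:

// Sketch only: chains this file's helpers on a raw logit slice.
func pick(logits []float32, k int, temp, p, mp float32, seed int64) int {
	ts := make([]token, len(logits))
	for i, v := range logits {
		ts[i] = token{id: int32(i), value: v}
	}
	ts = topK(ts, k)      // keep the k highest logits, sorted descending
	temperature(ts, temp) // scale logits in place
	softmax(ts)           // convert logits to probabilities
	ts = topP(ts, p)      // nucleus filter on cumulative probability
	ts = minP(ts, mp)     // drop tokens far below the most likely one
	return int(ts[dist(ts, seed)].id)
}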

View File

@@ -1,8 +1,13 @@
package sample
import (
"encoding/binary"
"errors"
"math"
"math/rand/v2"
"os"
"path/filepath"
"runtime"
"testing"
)
@@ -32,132 +37,64 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
}
}
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
tokens := toTokens(input)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, tokens)
func TestTemperatureAndSoftmax(t *testing.T) {
input := []float32{1, 4, -2, 0}
got := temperature(toTokens(input), 0.5)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, tokens)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, tokens)
}
func TestSoftmax(t *testing.T) {
tests := []struct {
name string
input []float32
expected []float32
}{
{
name: "correctness softmax",
input: []float32{1, -2, 3, 0},
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
},
{
name: "normal distribution",
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
},
{
name: "single value",
input: []float32{1.0},
},
{
name: "identical values",
input: []float32{0.9, 0.9, 0.9},
},
{
name: "large values",
input: []float32{1000.0, 2000.0, 3000.0},
},
{
name: "small values",
input: []float32{1e-6, 2e-6, 3e-6},
},
{
name: "negative values",
input: []float32{-1.0, -2.0, -3.0},
},
{
name: "mixed values",
input: []float32{-100.0, 0.0, 100.0},
},
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens := toTokens(tt.input)
softmax(tokens)
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, tokens)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range tokens {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
}
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
})
got = temperature(toTokens(input), 1)
// Check probabilities sum to 1
sum = 0.0
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
}
func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
tokens = topK(tokens, 5)
if len(tokens) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
// Test k=5
got := topK(toTokens(input), 5)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
}
// Should keep the highest 5 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(5)", want, tokens)
compareLogits(t, "topK(5)", want, got)
tokens = toTokens(input)
tokens = topK(tokens, 20)
if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
}
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, -1)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, tokens)
compareLogits(t, "topK(-1)", want, got)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, 0)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
@@ -165,134 +102,142 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
softmax(tokens)
tokens = temperature(tokens, 1)
tokens = topK(tokens, 20)
// Test with very high p value
got := topP(tokens, 1.0)
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
// Then apply topP
got := topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
tokens = topK(tokens, 20)
softmax(tokens)
tokens = temperature(tokens, 1)
tokens = minP(tokens, 1.0)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Then apply minP
got := minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
}
}
func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
got := minP(tokens, 1.0)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
// Test with normal p value
got = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
i, tokens[i].value, tokens[i-1].value)
}
}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
}
// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
func TestSortLogitsWithRealData(t *testing.T) {
// This will be populated from testdata/logits.bin
// Format: 32-bit float array in binary format
logits, err := loadTestLogits(t)
if err != nil {
t.Skipf("Skipping real logit test: %v", err)
return
}
tokens := toTokens(logits)
sortLogits(tokens)
// Calculate n for verification
n := int(math.Sqrt(float64(len(tokens)))) + 1
if n > 1000 {
n = 1000
} else if n < 100 {
n = 100
}
t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
// Only verify the top n elements are sorted (which is what we guarantee)
// This is much faster than checking the entire array
topN := tokens[:n]
for i := 1; i < len(topN); i++ {
if topN[i].value > topN[i-1].value {
t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
n, i, topN[i].value, topN[i-1].value)
}
}
// Verify we didn't lose any high value tokens by checking that
// all tokens after position n are <= the nth token
// Do this in chunks to avoid timeouts on large arrays
nthValue := tokens[n-1].value
const chunkSize = 1000
for start := n; start < len(tokens); start += chunkSize {
end := min(start+chunkSize, len(tokens))
for i := start; i < end; i++ {
if tokens[i].value > nthValue {
t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
n, i, tokens[i].value, nthValue)
}
}
}
}
// loadTestLogits loads logit test data from testdata/logits.bin
func loadTestLogits(t *testing.T) ([]float32, error) {
t.Helper()
_, currFile, _, ok := runtime.Caller(0)
if !ok {
return nil, errors.New("could not determine test file path")
}
testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
file, err := os.Open(testDataPath)
if err != nil {
return nil, err
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return nil, err
}
numFloats := stat.Size() / 4 // each float32 is 4 bytes
if numFloats*4 != stat.Size() {
return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
}
logits := make([]float32, numFloats)
for i := range logits {
var val uint32
if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
return nil, err
}
logits[i] = math.Float32frombits(val)
}
if len(logits) == 0 {
return nil, errors.New("logits.bin is empty")
}
return logits, nil
}
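Since encoding/binary's Read accepts a slice of fixed-size values, the per-element loop above could be collapsed into a single call; a hedged alternative for the same section of loadTestLogits:

// Alternative: read every little-endian float32 in one call.
logits := make([]float32, numFloats)
if err := binary.Read(file, binary.LittleEndian, logits); err != nil {
	return nil, err
}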
func BenchmarkTransforms(b *testing.B) {
@@ -315,19 +260,11 @@ func BenchmarkTransforms(b *testing.B) {
}
})
b.Run("Softmax", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
softmax(tokensCopy)
}
})
b.Run("TopK", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 10)
topK(tokensCopy, 10)
}
})
@@ -335,7 +272,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topP(tokensCopy, 0.9)
topP(tokensCopy, 0.9)
}
})
@@ -343,7 +280,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = minP(tokensCopy, 0.2)
minP(tokensCopy, 0.2)
}
})
@@ -351,7 +288,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 200000)
topK(tokensCopy, 200000)
}
})
}

View File

@@ -8,7 +8,7 @@ usage() {
exit 1
}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export VERSION=${VERSION:-$(git describe --tags --dirty)}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'

View File

@@ -29,9 +29,8 @@ import (
const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
)
var blobDownloadManager sync.Map
@@ -237,7 +236,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errMaxRedirectsExceeded
return errors.New("maximum redirects exceeded (10) for directURL")
}
// if the hostname is the same, allow the redirect

View File

@@ -35,9 +35,14 @@ var (
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errInsecureProtocol = errors.New("insecure protocol http")
)
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct {
@@ -60,83 +65,52 @@ type Model struct {
System string
License []string
Digest string
Options map[string]any
Options map[string]interface{}
Messages []api.Message
Template *template.Template
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
r, err := os.Open(m.ModelPath)
if err == nil {
defer r.Close()
f, _, err := ggml.Decode(r, 0)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
capabilities = append(capabilities, model.CapabilityCompletion)
}
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityVision)
}
} else {
slog.Error("couldn't decode ggml", "error", err)
}
} else {
slog.Error("couldn't open model file", "error", err)
}
if m.Template == nil {
return capabilities
}
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}
return capabilities
}
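A brief usage sketch for the Capabilities method above; GetModel and the model name here are assumptions for illustration, while the capability constants come from the types/model package already referenced in this file:

// Sketch: branch on vision support before attaching images to a request.
m, err := GetModel("llama3.2")
if err != nil {
	return err
}
if slices.Contains(m.Capabilities(), model.CapabilityVision) {
	// safe to include image data in the prompt
}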
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(want ...model.Capability) error {
available := m.Capabilities()
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
r, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer r.Close()
// Map capabilities to their corresponding error
capToErr := map[model.Capability]error{
model.CapabilityCompletion: errCapabilityCompletion,
model.CapabilityTools: errCapabilityTools,
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
}
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
f, _, err := ggml.Decode(r, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
for _, cap := range want {
err, ok := capToErr[cap]
if !ok {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default:
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
}
if !slices.Contains(available, cap) {
errs = append(errs, err)
}
}
if len(errs) > 0 {
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
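And the corresponding enforcement-style call; because missing capabilities are wrapped in errCapabilities above, a caller can branch on it (handler context assumed):

// Sketch: reject a tools request early if the model cannot honor it.
if err := m.CheckCapabilities(model.CapabilityTools, model.CapabilityInsert); err != nil {
	if errors.Is(err, errCapabilities) {
		// respond 400: the model lacks tools and/or insert support
	}
	return err
}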
@@ -505,7 +479,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
return errors.New("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
@@ -569,7 +543,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
return errors.New("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})

View File

@@ -1,360 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// Constants for GGUF magic bytes and version
var (
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
ggufVer = uint32(3) // Version 3
)
// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
var buf bytes.Buffer
// Write GGUF header
buf.Write(ggufMagic)
binary.Write(&buf, binary.LittleEndian, ggufVer)
// Write tensor count (0 for our test)
var numTensors uint64 = 0
binary.Write(&buf, binary.LittleEndian, numTensors)
// Calculate number of metadata entries
numMetaEntries := uint64(1) // architecture entry
if vision {
numMetaEntries++
}
// Add embedding entry if architecture is "bert"
if architecture == "bert" {
numMetaEntries++
}
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
// Write architecture metadata
archKey := "general.architecture"
keyLen := uint64(len(archKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(archKey)
// String type (8)
var strType uint32 = 8
binary.Write(&buf, binary.LittleEndian, strType)
// String length
strLen := uint64(len(architecture))
binary.Write(&buf, binary.LittleEndian, strLen)
buf.WriteString(architecture)
if vision {
visionKey := architecture + ".vision.block_count"
keyLen = uint64(len(visionKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(visionKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var countVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, countVal)
}
// Write embedding metadata if architecture is "bert"
if architecture == "bert" {
poolKey := architecture + ".pooling_type"
keyLen = uint64(len(poolKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(poolKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var poolingVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, poolingVal)
}
return buf.Bytes()
}
func TestModelCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
// Create different types of mock model files
completionModelPath := filepath.Join(tempDir, "model.bin")
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
// Create a simple model file for tests that don't depend on GGUF content
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create completion model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
}
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools and insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}

func TestModelCheckCapabilities(t *testing.T) {
// Create a temporary directory for test files
tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
simpleModelPath := filepath.Join(tempDir, "model.bin")
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
if err != nil {
t.Fatalf("Failed to create simple model file: %v", err)
}
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
if err != nil {
t.Fatalf("Failed to create vision model file: %v", err)
}
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
if err != nil {
t.Fatalf("Failed to create embedding model file: %v", err)
}
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
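	// The cases below expect CheckCapabilities to report each missing
	// capability as a "does not support <name>" error and to reject
	// capability names it does not recognize.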
tests := []struct {
name string
model Model
checkCaps []model.Capability
expectedErrMsg string
}{
{
name: "completion model without tools capability",
model: Model{
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools},
expectedErrMsg: "does not support tools",
},
{
name: "model with all needed capabilities",
model: Model{
ModelPath: simpleModelPath,
Template: toolsInsertTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model missing insert capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityInsert},
expectedErrMsg: "does not support insert",
},
{
name: "model missing vision capability",
model: Model{
ModelPath: simpleModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
expectedErrMsg: "does not support vision",
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "unknown capability",
model: Model{
ModelPath: simpleModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test CheckCapabilities method
err := tt.model.CheckCapabilities(tt.checkCaps...)
if tt.expectedErrMsg == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
}
}
})
}
}
