mirror of
https://github.com/ollama/ollama.git
synced 2025-12-24 08:10:54 -05:00
Compare commits
116 Commits
parth/samp
...
parth/pyth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4cd1118ab | ||
|
|
128c90d3ac | ||
|
|
f5872a097c | ||
|
|
3ac5e0f102 | ||
|
|
ef65174df2 | ||
|
|
42ecb9f138 | ||
|
|
5c0331fd83 | ||
|
|
e7019c9455 | ||
|
|
d98bfe7e70 | ||
|
|
6747099d71 | ||
|
|
ccc8c6777b | ||
|
|
dbb149e6f7 | ||
|
|
a807985e59 | ||
|
|
8643c4d5bf | ||
|
|
b0c3aba590 | ||
|
|
19c0c25de8 | ||
|
|
2f723ac2d6 | ||
|
|
249fbbe52f | ||
|
|
c38680b8a1 | ||
|
|
16fca86c4a | ||
|
|
0f3f9e353d | ||
|
|
6bd0a983cd | ||
|
|
1861fbdeb5 | ||
|
|
3b96a93672 | ||
|
|
e53b3cbd0c | ||
|
|
b51e0f397c | ||
|
|
b42970063d | ||
|
|
493385eb3e | ||
|
|
9876c9faa4 | ||
|
|
4e415029b3 | ||
|
|
e172f095ba | ||
|
|
c001b98087 | ||
|
|
23fc8e92eb | ||
|
|
4059a297a6 | ||
|
|
66b2539238 | ||
|
|
ef27d52e79 | ||
|
|
b2a465296d | ||
|
|
5d097277ef | ||
|
|
071a9872cb | ||
|
|
0bd0454ea7 | ||
|
|
01aa788722 | ||
|
|
ead27aa9fe | ||
|
|
b816ff86c9 | ||
|
|
e5d84fb90b | ||
|
|
dd66712e31 | ||
|
|
f66216e399 | ||
|
|
f4f0992b6e | ||
|
|
1feff61977 | ||
|
|
5e0b904e88 | ||
|
|
131f0355a5 | ||
|
|
ce929984a3 | ||
|
|
4b34930a31 | ||
|
|
74bd09652d | ||
|
|
fb6252d786 | ||
|
|
c794fef2f2 | ||
|
|
00ebda8cc4 | ||
|
|
d14ce75b95 | ||
|
|
2d6eac9084 | ||
|
|
3ed7ad3ab3 | ||
|
|
6d1103048e | ||
|
|
0ff28758b3 | ||
|
|
d3e9ca3eda | ||
|
|
0fbfcf3c9c | ||
|
|
0c220935bd | ||
|
|
ffbfe833da | ||
|
|
42a14f7f63 | ||
|
|
f8c3dbe5b5 | ||
|
|
b078dd157c | ||
|
|
2ddacd7516 | ||
|
|
da0e345200 | ||
|
|
df94175a0f | ||
|
|
61a8825216 | ||
|
|
021dcf089d | ||
|
|
bf24498b1e | ||
|
|
95e271d98f | ||
|
|
364629b8d6 | ||
|
|
108fe02165 | ||
|
|
4561fff36e | ||
|
|
50b5962042 | ||
|
|
e27e4a3c1b | ||
|
|
088514bbd4 | ||
|
|
2c8b484643 | ||
|
|
8294676150 | ||
|
|
ef378ad673 | ||
|
|
2d2247e59e | ||
|
|
7bf793a600 | ||
|
|
282bfaaa95 | ||
|
|
9679f40146 | ||
|
|
3892c3a703 | ||
|
|
4e320b8b90 | ||
|
|
eb2b22b042 | ||
|
|
4ea4d2b189 | ||
|
|
8d76fa23ef | ||
|
|
74b44fdf8f | ||
|
|
65b88c544f | ||
|
|
a422ba39c9 | ||
|
|
d2ec22371e | ||
|
|
033cec232a | ||
|
|
543240fb5f | ||
|
|
4bed739259 | ||
|
|
80c7ce381b | ||
|
|
ccfd41c4f0 | ||
|
|
3e102b7dad | ||
|
|
ec46f3286c | ||
|
|
5e2e0b46b1 | ||
|
|
45a13b1dec | ||
|
|
5c0b663969 | ||
|
|
30d7a59ba8 | ||
|
|
4aeb67ef4c | ||
|
|
3ba91634c1 | ||
|
|
1b7433b71e | ||
|
|
a70820daa0 | ||
|
|
6b45b1d6b4 | ||
|
|
85ab552028 | ||
|
|
b3af953a55 | ||
|
|
ad4e0bf3be |
@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
|
||||
)
|
||||
endif()
|
||||
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
|
||||
CACHE STRING
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
|
||||
)
|
||||
|
||||
check_language(HIP)
|
||||
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -51,7 +51,7 @@ see if the change were accepted.
|
||||
|
||||
The title should look like:
|
||||
|
||||
<package>: <short description>
|
||||
<package>: <short description>
|
||||
|
||||
The package is the most affected Go package. If the change does not affect Go
|
||||
code, then use the directory name instead. Changes to a single well-known
|
||||
|
||||
@@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||
|
||||
21
README.md
21
README.md
@@ -54,6 +54,10 @@ Here are some example models that can be downloaded:
|
||||
|
||||
| Model | Parameters | Size | Download |
|
||||
| ------------------ | ---------- | ----- | -------------------------------- |
|
||||
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
|
||||
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
|
||||
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
|
||||
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
|
||||
| QwQ | 32B | 20GB | `ollama run qwq` |
|
||||
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
||||
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
||||
@@ -66,9 +70,6 @@ Here are some example models that can be downloaded:
|
||||
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
|
||||
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
|
||||
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
|
||||
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
|
||||
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
|
||||
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||
@@ -284,12 +285,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
||||
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
||||
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
|
||||
- [big-AGI](https://github.com/enricoros/big-AGI)
|
||||
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
|
||||
- [Amica](https://github.com/semperai/amica)
|
||||
- [chatd](https://github.com/BruceMacD/chatd)
|
||||
@@ -323,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
|
||||
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
|
||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
||||
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
|
||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
|
||||
@@ -345,7 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
|
||||
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
|
||||
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
|
||||
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
|
||||
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
|
||||
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
|
||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
|
||||
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
|
||||
@@ -391,6 +394,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -430,7 +437,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
@@ -509,6 +519,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
|
||||
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
|
||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
|
||||
### Mobile
|
||||
|
||||
|
||||
104
api/types.go
104
api/types.go
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
@@ -81,7 +82,7 @@ type GenerateRequest struct {
|
||||
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -106,7 +107,7 @@ type ChatRequest struct {
|
||||
Tools `json:"tools,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -162,19 +163,65 @@ func (t *ToolCallFunctionArguments) String() string {
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Function ToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
// PropertyType can be either a string or an array of strings
|
||||
type PropertyType []string
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface
|
||||
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as a string first
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
*pt = []string{s}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If that fails, try to unmarshal as an array of strings
|
||||
var a []string
|
||||
if err := json.Unmarshal(data, &a); err != nil {
|
||||
return err
|
||||
}
|
||||
*pt = a
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (pt PropertyType) MarshalJSON() ([]byte, error) {
|
||||
if len(pt) == 1 {
|
||||
// If there's only one type, marshal as a string
|
||||
return json.Marshal(pt[0])
|
||||
}
|
||||
// Otherwise marshal as an array
|
||||
return json.Marshal([]string(pt))
|
||||
}
|
||||
|
||||
// String returns a string representation of the PropertyType
|
||||
func (pt PropertyType) String() string {
|
||||
if len(pt) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(pt) == 1 {
|
||||
return pt[0]
|
||||
}
|
||||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
Type PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
}
|
||||
@@ -260,7 +307,7 @@ type EmbedRequest struct {
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// EmbedResponse is the response from [Client.Embed].
|
||||
@@ -286,7 +333,7 @@ type EmbeddingRequest struct {
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||
@@ -332,7 +379,7 @@ type ShowRequest struct {
|
||||
Template string `json:"template"`
|
||||
Verbose bool `json:"verbose"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -340,16 +387,18 @@ type ShowRequest struct {
|
||||
|
||||
// ShowResponse is the response returned from [Client.Show].
|
||||
type ShowResponse struct {
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
License string `json:"license,omitempty"`
|
||||
Modelfile string `json:"modelfile,omitempty"`
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
}
|
||||
|
||||
// CopyRequest is the request passed to [Client.Copy].
|
||||
@@ -467,6 +516,13 @@ type ModelDetails struct {
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Shape []uint64 `json:"shape"`
|
||||
}
|
||||
|
||||
func (m *Metrics) Summary() {
|
||||
if m.TotalDuration > 0 {
|
||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||
@@ -495,7 +551,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
func (opts *Options) FromMap(m map[string]any) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
|
||||
@@ -552,12 +608,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
}
|
||||
field.SetString(val)
|
||||
case reflect.Slice:
|
||||
// JSON unmarshals to []interface{}, not []string
|
||||
val, ok := val.([]interface{})
|
||||
// JSON unmarshals to []any, not []string
|
||||
val, ok := val.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type array", key)
|
||||
}
|
||||
// convert []interface{} to []string
|
||||
// convert []any to []string
|
||||
slice := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
str, ok := item.(string)
|
||||
@@ -664,7 +720,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
}
|
||||
|
||||
// FormatParams converts specified parameter options to their correct types
|
||||
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||
func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||
opts := Options{}
|
||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||
@@ -678,7 +734,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||
}
|
||||
}
|
||||
|
||||
out := make(map[string]interface{})
|
||||
out := make(map[string]any)
|
||||
// iterate params and set values based on json struct tags
|
||||
for key, vals := range params {
|
||||
if opt, ok := jsonOpts[key]; !ok {
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var oMap map[string]interface{}
|
||||
var oMap map[string]any
|
||||
err := json.Unmarshal([]byte(test.req), &oMap)
|
||||
require.NoError(t, err)
|
||||
opts := DefaultOptions()
|
||||
@@ -231,3 +231,144 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunction_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid enum with same types",
|
||||
input: `{
|
||||
"name": "test",
|
||||
"description": "test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["test"],
|
||||
"properties": {
|
||||
"test": {
|
||||
"type": "string",
|
||||
"description": "test prop",
|
||||
"enum": ["a", "b", "c"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty enum array",
|
||||
input: `{
|
||||
"name": "test",
|
||||
"description": "test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["test"],
|
||||
"properties": {
|
||||
"test": {
|
||||
"type": "string",
|
||||
"description": "test prop",
|
||||
"enum": []
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tf ToolFunction
|
||||
err := json.Unmarshal([]byte(tt.input), &tf)
|
||||
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropertyType_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected PropertyType
|
||||
}{
|
||||
{
|
||||
name: "string type",
|
||||
input: `"string"`,
|
||||
expected: PropertyType{"string"},
|
||||
},
|
||||
{
|
||||
name: "array of types",
|
||||
input: `["string", "number"]`,
|
||||
expected: PropertyType{"string", "number"},
|
||||
},
|
||||
{
|
||||
name: "array with single type",
|
||||
input: `["string"]`,
|
||||
expected: PropertyType{"string"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var pt PropertyType
|
||||
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(pt) != len(test.expected) {
|
||||
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
|
||||
}
|
||||
|
||||
for i, v := range pt {
|
||||
if v != test.expected[i] {
|
||||
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input PropertyType
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single type",
|
||||
input: PropertyType{"string"},
|
||||
expected: `"string"`,
|
||||
},
|
||||
{
|
||||
name: "multiple types",
|
||||
input: PropertyType{"string", "number"},
|
||||
expected: `["string","number"]`,
|
||||
},
|
||||
{
|
||||
name: "empty type",
|
||||
input: PropertyType{},
|
||||
expected: `[]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(test.input)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != test.expected {
|
||||
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
178
benchmark/server_benchmark_test.go
Normal file
178
benchmark/server_benchmark_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Command line flags
|
||||
var modelFlag string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||
flag.Lookup("m").DefValue = "model"
|
||||
}
|
||||
|
||||
// modelName returns the model name from flags, failing the test if not set
|
||||
func modelName(b *testing.B) string {
|
||||
if modelFlag == "" {
|
||||
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||
}
|
||||
return modelFlag
|
||||
}
|
||||
|
||||
type TestCase struct {
|
||||
name string
|
||||
prompt string
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// runGenerateBenchmark contains the common generate and metrics logic
|
||||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||
start := time.Now()
|
||||
var ttft time.Duration
|
||||
var metrics api.Metrics
|
||||
|
||||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if ttft == 0 && resp.Response != "" {
|
||||
ttft = time.Since(start)
|
||||
}
|
||||
if resp.Done {
|
||||
metrics = resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Report custom metrics as part of the benchmark results
|
||||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||
|
||||
// Token throughput metrics
|
||||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||
|
||||
// Token counts
|
||||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||
func BenchmarkColdStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
b.StopTimer()
|
||||
// Ensure model is unloaded before each iteration
|
||||
unload(client, m, b)
|
||||
b.StartTimer()
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||
func BenchmarkWarmStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setup verifies server and model availability
|
||||
func setup(b *testing.B) *api.Client {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// warmup ensures the model is loaded and warmed up
|
||||
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
for range 3 {
|
||||
err := client.Generate(
|
||||
context.Background(),
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
b.Logf("Error during model warm-up: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unload forces model unloading using KeepAlive: 0 parameter
|
||||
func unload(client *api.Client, model string, b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||
b.Logf("Unload error: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
72
cmd/cmd.go
72
cmd/cmd.go
@@ -18,6 +18,8 @@ import (
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -266,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@@ -338,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
// that don't have the capabilities field in the model info
|
||||
if len(info.ProjectorInfo) != 0 {
|
||||
opts.MultiModal = true
|
||||
}
|
||||
@@ -568,8 +575,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
parameters, errParams := cmd.Flags().GetBool("parameters")
|
||||
system, errSystem := cmd.Flags().GetBool("system")
|
||||
template, errTemplate := cmd.Flags().GetBool("template")
|
||||
verbose, errVerbose := cmd.Flags().GetBool("verbose")
|
||||
|
||||
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
|
||||
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
|
||||
if boolErr != nil {
|
||||
return errors.New("error retrieving flags")
|
||||
}
|
||||
@@ -607,7 +615,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
|
||||
}
|
||||
|
||||
req := api.ShowRequest{Name: args[0]}
|
||||
req := api.ShowRequest{Name: args[0], Verbose: verbose}
|
||||
resp, err := client.Show(cmd.Context(), &req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -630,10 +638,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return showInfo(resp, os.Stdout)
|
||||
return showInfo(resp, verbose, os.Stdout)
|
||||
}
|
||||
|
||||
func showInfo(resp *api.ShowResponse, w io.Writer) error {
|
||||
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
tableRender := func(header string, rows func() [][]string) {
|
||||
fmt.Fprintln(w, " ", header)
|
||||
table := tablewriter.NewWriter(w)
|
||||
@@ -667,6 +675,15 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
|
||||
return
|
||||
})
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
for _, capability := range resp.Capabilities {
|
||||
rows = append(rows, []string{"", capability.String()})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if resp.ProjectorInfo != nil {
|
||||
tableRender("Projector", func() (rows [][]string) {
|
||||
arch := resp.ProjectorInfo["general.architecture"].(string)
|
||||
@@ -690,6 +707,47 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
|
||||
})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil && verbose {
|
||||
tableRender("Metadata", func() (rows [][]string) {
|
||||
keys := make([]string, 0, len(resp.ModelInfo))
|
||||
for k := range resp.ModelInfo {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
var v string
|
||||
switch vData := resp.ModelInfo[k].(type) {
|
||||
case bool:
|
||||
v = fmt.Sprintf("%t", vData)
|
||||
case string:
|
||||
v = vData
|
||||
case float64:
|
||||
v = fmt.Sprintf("%g", vData)
|
||||
case []any:
|
||||
n := 3
|
||||
if len(vData) < n {
|
||||
n = len(vData)
|
||||
}
|
||||
v = fmt.Sprintf("%v", vData[:n])
|
||||
default:
|
||||
v = fmt.Sprintf("%T", vData)
|
||||
}
|
||||
rows = append(rows, []string{"", k, v})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if len(resp.Tensors) > 0 && verbose {
|
||||
tableRender("Tensors", func() (rows [][]string) {
|
||||
for _, t := range resp.Tensors {
|
||||
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
head := func(s string, n int) (rows [][]string) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||
for scanner.Scan() && (len(rows) < n || n < 0) {
|
||||
@@ -794,7 +852,7 @@ type runOptions struct {
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]interface{}
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
}
|
||||
@@ -1196,6 +1254,7 @@ func NewCLI() *cobra.Command {
|
||||
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
|
||||
showCmd.Flags().Bool("template", false, "Show template of a model")
|
||||
showCmd.Flags().Bool("system", false, "Show system message of a model")
|
||||
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run MODEL [PROMPT]",
|
||||
@@ -1322,7 +1381,6 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_NOPRUNE"],
|
||||
envVars["OLLAMA_ORIGINS"],
|
||||
envVars["OLLAMA_SCHED_SPREAD"],
|
||||
envVars["OLLAMA_TMPDIR"],
|
||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||
envVars["OLLAMA_KV_CACHE_TYPE"],
|
||||
envVars["OLLAMA_LLM_LIBRARY"],
|
||||
|
||||
224
cmd/cmd_test.go
224
cmd/cmd_test.go
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@@ -27,7 +28,7 @@ func TestShowInfo(t *testing.T) {
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -57,7 +58,7 @@ func TestShowInfo(t *testing.T) {
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -68,6 +69,60 @@ func TestShowInfo(t *testing.T) {
|
||||
embedding length 0
|
||||
quantization FP16
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("verbose model", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "8B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Parameters: `
|
||||
stop up`,
|
||||
ModelInfo: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"general.parameter_count": float64(8_000_000_000),
|
||||
"some.true_bool": true,
|
||||
"some.false_bool": false,
|
||||
"test.context_length": float64(1000),
|
||||
"test.embedding_length": float64(11434),
|
||||
},
|
||||
Tensors: []api.Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
|
||||
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
|
||||
},
|
||||
}, true, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := ` Model
|
||||
architecture test
|
||||
parameters 8B
|
||||
context length 1000
|
||||
embedding length 11434
|
||||
quantization FP16
|
||||
|
||||
Parameters
|
||||
stop up
|
||||
|
||||
Metadata
|
||||
general.architecture test
|
||||
general.parameter_count 8e+09
|
||||
some.false_bool false
|
||||
some.true_bool true
|
||||
test.context_length 1000
|
||||
test.embedding_length 11434
|
||||
|
||||
Tensors
|
||||
blk.0.attn_k.weight BF16 [42 3117]
|
||||
blk.0.attn_q.weight FP16 [3117 42]
|
||||
|
||||
`
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
@@ -89,7 +144,7 @@ func TestShowInfo(t *testing.T) {
|
||||
stop you
|
||||
stop up
|
||||
temperature 99`,
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -126,7 +181,7 @@ func TestShowInfo(t *testing.T) {
|
||||
"clip.vision.embedding_length": float64(0),
|
||||
"clip.vision.projection_dim": float64(0),
|
||||
},
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -159,7 +214,7 @@ func TestShowInfo(t *testing.T) {
|
||||
Ahoy, matey!
|
||||
Weigh anchor!
|
||||
`,
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -188,7 +243,7 @@ Weigh anchor!
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
License: license,
|
||||
}, &b); err != nil {
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -206,6 +261,34 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("capabilities", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture test \n" +
|
||||
" parameters 7B \n" +
|
||||
" quantization FP16 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" vision \n" +
|
||||
" tools \n" +
|
||||
"\n"
|
||||
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
@@ -707,3 +790,132 @@ func TestCreateHandler(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCreateRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
from string
|
||||
opts runOptions
|
||||
expected *api.CreateRequest
|
||||
}{
|
||||
{
|
||||
"basic test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "",
|
||||
Prompt: "You are a fun AI agent",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model as filepath test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "/some/file/like/etc/passwd",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"parent model as windows filepath test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
|
||||
Messages: []api.Message{},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "mymodel",
|
||||
Model: "newmodel",
|
||||
},
|
||||
},
|
||||
{
|
||||
"options test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
},
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
Parameters: map[string]any{
|
||||
"temperature": 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"messages test",
|
||||
"newmodel",
|
||||
runOptions{
|
||||
Model: "mymodel",
|
||||
ParentModel: "parentmodel",
|
||||
System: "You are a fun AI agent",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello there!",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "hello to you!",
|
||||
},
|
||||
},
|
||||
WordWrap: true,
|
||||
},
|
||||
&api.CreateRequest{
|
||||
From: "parentmodel",
|
||||
Model: "newmodel",
|
||||
System: "You are a fun AI agent",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "hello there!",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "hello to you!",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := NewCreateRequest(tt.from, tt.opts)
|
||||
if !cmp.Equal(actual, tt.expected) {
|
||||
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
@@ -195,6 +196,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
continue
|
||||
@@ -343,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
switch args[1] {
|
||||
case "info":
|
||||
_ = showInfo(resp, os.Stderr)
|
||||
_ = showInfo(resp, false, os.Stderr)
|
||||
case "license":
|
||||
if resp.License == "" {
|
||||
fmt.Println("No license was specified for this model.")
|
||||
@@ -455,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
|
||||
parentModel := opts.ParentModel
|
||||
|
||||
modelName := model.ParseName(parentModel)
|
||||
if !modelName.IsValid() {
|
||||
parentModel = ""
|
||||
}
|
||||
|
||||
req := &api.CreateRequest{
|
||||
Name: name,
|
||||
From: cmp.Or(opts.ParentModel, opts.Model),
|
||||
Model: name,
|
||||
From: cmp.Or(parentModel, opts.Model),
|
||||
}
|
||||
|
||||
if opts.System != "" {
|
||||
|
||||
@@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
|
||||
var conv ModelConverter
|
||||
switch p.Architectures[0] {
|
||||
case "LlamaForCausalLM", "MistralForCausalLM":
|
||||
case "LlamaForCausalLM":
|
||||
conv = &llamaModel{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
conv = &mistral3Model{}
|
||||
case "MixtralForCausalLM":
|
||||
conv = &mixtralModel{}
|
||||
case "GemmaForCausalLM":
|
||||
@@ -201,7 +203,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
default:
|
||||
return errors.New("unsupported architecture")
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.embedding_length"] = p.HiddenSize
|
||||
kv["gemma3.feed_forward_length"] = p.IntermediateSize
|
||||
default:
|
||||
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
|
||||
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
|
||||
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
|
||||
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
|
||||
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
|
||||
|
||||
190
convert/convert_mistral.go
Normal file
190
convert/convert_mistral.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistral3Model struct {
|
||||
ModelParameters
|
||||
ImageTokenIndex uint32 `json:"image_token_index"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
VisionFeatureLayer int32 `json:"vision_feature_layer"`
|
||||
TextModel struct {
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"vision_config"`
|
||||
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
// Text configuration
|
||||
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
|
||||
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
|
||||
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
|
||||
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
|
||||
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||
|
||||
// Vision configuration
|
||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
||||
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
||||
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
|
||||
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
|
||||
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
|
||||
// Multimodal configuration
|
||||
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
|
||||
|
||||
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
|
||||
|
||||
if p.ProjectorHiddenAct != "" {
|
||||
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Replacements() []string {
|
||||
return []string{
|
||||
"language_model.model.norm", "output_norm",
|
||||
"language_model.model.", "",
|
||||
"language_model.", "",
|
||||
"layers", "blk",
|
||||
"transformer.layers", "blk",
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"embed_tokens", "token_embd",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"multi_modal_projector", "mm",
|
||||
"ffn_norm", "ffn_norm",
|
||||
"lm_head", "output",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, ".attn_q.weight") {
|
||||
heads = p.TextModel.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, ".attn_k.weight") {
|
||||
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
}
|
||||
|
||||
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
@@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
|
||||
Pattern string
|
||||
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
|
||||
}{
|
||||
{"model-*-of-*.safetensors", parseSafetensors},
|
||||
{"model.safetensors", parseSafetensors},
|
||||
{"adapters.safetensors", parseSafetensors},
|
||||
{"adapter_model.safetensors", parseSafetensors},
|
||||
{"*.safetensors", parseSafetensors},
|
||||
{"pytorch_model-*-of-*.bin", parseTorch},
|
||||
{"pytorch_model.bin", parseTorch},
|
||||
{"consolidated.*.pth", parseTorch},
|
||||
|
||||
@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
|
||||
|
||||
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_sentencepiece_model_proto_goTypes = []interface{}{
|
||||
var file_sentencepiece_model_proto_goTypes = []any{
|
||||
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
|
||||
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
|
||||
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
|
||||
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*TrainerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*NormalizerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData_Sample); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto_SentencePiece); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
|
||||
@@ -12,7 +12,7 @@ func IsNUMA() bool {
|
||||
// numa support in llama.cpp is linux only
|
||||
return false
|
||||
}
|
||||
ids := map[string]interface{}{}
|
||||
ids := map[string]any{}
|
||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||
for _, packageId := range packageIds {
|
||||
id, err := os.ReadFile(packageId)
|
||||
|
||||
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
for id, s := range socketByID {
|
||||
s.CoreCount = len(coreBySocket[id])
|
||||
s.ThreadCount = 0
|
||||
for _, tc := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += tc
|
||||
}
|
||||
|
||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||
efficiencyCoreCount := 0
|
||||
for _, threads := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += threads
|
||||
if threads == 1 {
|
||||
efficiencyCoreCount++
|
||||
}
|
||||
|
||||
12
docs/api.md
12
docs/api.md
@@ -558,6 +558,10 @@ Final response:
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"done": true,
|
||||
"total_duration": 4883583458,
|
||||
"load_duration": 1334875,
|
||||
@@ -1213,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameter
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
"model": "llama3.2"
|
||||
"model": "llava"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -1256,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{
|
||||
"tokenizer.ggml.pre": "llama-bpe",
|
||||
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
||||
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
||||
}
|
||||
},
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
59
docs/benchmark.md
Normal file
59
docs/benchmark.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Benchmark
|
||||
|
||||
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||
|
||||
## When to use
|
||||
|
||||
Run these benchmarks when:
|
||||
- Making changes to the model inference engine
|
||||
- Modifying model loading/unloading logic
|
||||
- Changing prompt processing or token generation code
|
||||
- Implementing a new model architecture
|
||||
- Testing performance across different hardware setups
|
||||
|
||||
## Prerequisites
|
||||
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||
## Usage and Examples
|
||||
|
||||
>[!NOTE]
|
||||
>All commands must be run from the root directory of the Ollama project.
|
||||
|
||||
Basic syntax:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||
```
|
||||
|
||||
Required flags:
|
||||
- `-bench=.`: Run all benchmarks
|
||||
- `-m`: Model name to benchmark
|
||||
|
||||
Optional flags:
|
||||
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||
|
||||
Common usage patterns:
|
||||
|
||||
Single benchmark run with a model specified:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m llama3.3
|
||||
```
|
||||
|
||||
## Output metrics
|
||||
|
||||
The benchmark reports several key metrics:
|
||||
|
||||
- `gen_tok/s`: Generated tokens per second
|
||||
- `prompt_tok/s`: Prompt processing tokens per second
|
||||
- `ttft_ms`: Time to first token in milliseconds
|
||||
- `load_ms`: Model load time in milliseconds
|
||||
- `gen_tokens`: Total tokens generated
|
||||
- `prompt_tokens`: Total prompt tokens processed
|
||||
|
||||
Each benchmark runs two scenarios:
|
||||
- Cold start: Model is loaded from disk for each test
|
||||
- Warm start: Model is pre-loaded in memory
|
||||
|
||||
Three prompt lengths are tested for each scenario:
|
||||
- Short prompt (100 tokens)
|
||||
- Medium prompt (500 tokens)
|
||||
- Long prompt (1000 tokens)
|
||||
15
docs/faq.md
15
docs/faq.md
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
```
|
||||
|
||||
To change this when using `ollama run`, use `/set parameter`:
|
||||
|
||||
@@ -187,6 +193,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
|
||||
|
||||
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
|
||||
|
||||
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
|
||||
|
||||
```
|
||||
# Allow all Chrome, Firefox, and Safari extensions
|
||||
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
|
||||
```
|
||||
|
||||
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
|
||||
|
||||
## Where are models stored?
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
@@ -26,7 +26,6 @@ When you run Ollama on **Windows**, there are a few different locations. You can
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
|
||||
|
||||
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
|
||||
|
||||
@@ -69,10 +68,6 @@ If you run into problems on Linux and want to install an older version, or you'd
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
|
||||
```
|
||||
|
||||
## Linux tmp noexec
|
||||
|
||||
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
|
||||
|
||||
## Linux docker
|
||||
|
||||
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
|
||||
|
||||
@@ -62,7 +62,6 @@ the explorer window by hitting `<Ctrl>+R` and type in:
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
|
||||
## Uninstall
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func assertEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
func assertEqual(t *testing.T, a any, b any) {
|
||||
if a != b {
|
||||
t.Errorf("Assert failed, expected %v, got %v", b, a)
|
||||
}
|
||||
|
||||
13
fs/config.go
Normal file
13
fs/config.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package fs
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
Uint(string, ...uint32) uint32
|
||||
Float(string, ...float32) float32
|
||||
Bool(string, ...bool) bool
|
||||
|
||||
Strings(string, ...[]string) []string
|
||||
Uints(string, ...[]uint32) []uint32
|
||||
Floats(string, ...[]float32) []float32
|
||||
}
|
||||
109
fs/ggml/ggml.go
109
fs/ggml/ggml.go
@@ -134,7 +134,10 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return kv.Architecture() == "gemma3"
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||
@@ -327,6 +330,10 @@ func (t Tensor) Size() uint64 {
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
}
|
||||
|
||||
func (t Tensor) Type() string {
|
||||
return fileType(t.Kind).String()
|
||||
}
|
||||
|
||||
type container interface {
|
||||
Name() string
|
||||
Decode(io.ReadSeeker) (model, error)
|
||||
@@ -409,7 +416,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
@@ -422,7 +429,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
@@ -456,16 +466,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
case "mllama":
|
||||
var visionTokens, tiles uint64 = 1601, 4
|
||||
|
||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
kv = headsKV *
|
||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||
(2* // sizeof(float16)
|
||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
context +
|
||||
4* // sizeof(float32)
|
||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||
visionTokens*
|
||||
tiles)
|
||||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
}
|
||||
}
|
||||
|
||||
fullOffload = max(
|
||||
@@ -501,6 +509,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
const gemma3GlobalCacheCount = 6
|
||||
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||
for i := range kv {
|
||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
@@ -579,39 +601,52 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
}
|
||||
|
||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
if llm.KV().Uint("vision.block_count") == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for name, layer := range llm.Tensors().GroupLayers() {
|
||||
if name == "v" || strings.HasPrefix(name, "v.") {
|
||||
for _, tensor := range layer {
|
||||
weights += tensor.Size()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageSize := uint64(llm.KV().Uint("vision.image_size"))
|
||||
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
|
||||
if patchSize == 0 {
|
||||
slog.Warn("unknown patch size for vision model")
|
||||
return
|
||||
}
|
||||
|
||||
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
|
||||
|
||||
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
|
||||
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "mllama":
|
||||
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
kv := func(n string) uint64 {
|
||||
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
||||
return uint64(v)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
imageSize := kv("image_size")
|
||||
|
||||
maxNumTiles := kv("max_num_tiles")
|
||||
embeddingLength := kv("embedding_length")
|
||||
headCount := kv("attention.head_count")
|
||||
|
||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||
|
||||
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
|
||||
|
||||
graphSize = 4 * (8 +
|
||||
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
||||
imageSize*imageSize*numChannels*maxNumTiles +
|
||||
embeddingLength*numPatches*maxNumTiles +
|
||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||
case "gemma3", "mistral3":
|
||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||
embeddingLength*patchSize +
|
||||
numPatches*numPatches*headCount)
|
||||
}
|
||||
|
||||
return weights, graphSize
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
// Workaround deepseek context shifting bug
|
||||
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
|
||||
Model: "llava:7b",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
|
||||
Model: "x/llama3.2-vision",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
|
||||
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
|
||||
|
||||
@@ -20,7 +20,7 @@ var (
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -28,7 +28,7 @@ var (
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedCtx := ctx
|
||||
|
||||
var genwg sync.WaitGroup
|
||||
genwg.Add(1)
|
||||
go func() {
|
||||
genwg.Add(1)
|
||||
defer genwg.Done()
|
||||
slog.Info("Starting generate request")
|
||||
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
|
||||
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
|
||||
counterMu := sync.Mutex{}
|
||||
var embedwg sync.WaitGroup
|
||||
for i := 0; i < threadCount; i++ {
|
||||
embedwg.Add(1)
|
||||
go func(i int) {
|
||||
embedwg.Add(1)
|
||||
defer embedwg.Done()
|
||||
slog.Info("embed started", "id", i)
|
||||
embedReq := api.EmbeddingRequest{
|
||||
|
||||
@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
||||
@@ -43,20 +43,31 @@ type Cache interface {
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters
|
||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs.
|
||||
StartForward(ctx ml.Context, opts input.Options) error
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
|
||||
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
Capacity int32
|
||||
windowSize int32
|
||||
|
||||
opts CausalOptions
|
||||
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
c.DType = dtype
|
||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
||||
c.cells = make([]cacheCell, c.Capacity)
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
}
|
||||
@@ -140,49 +146,60 @@ func (c *Causal) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
c.curBatchSize = len(opts.Positions)
|
||||
c.curSequences = opts.Sequences
|
||||
c.curPositions = opts.Positions
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
if !reserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range opts.Positions {
|
||||
seq := opts.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
// If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// to the worst case.
|
||||
c.curLoc = 0
|
||||
c.curCellRange.min = 0
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
|
||||
return err
|
||||
@@ -210,7 +227,51 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newRange := newRange()
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.windowSize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = newRange
|
||||
}
|
||||
}
|
||||
|
||||
func roundDown(length, pad int) int {
|
||||
@@ -265,7 +326,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
return maskTensor, nil
|
||||
}
|
||||
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
@@ -275,8 +336,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
@@ -284,14 +345,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
@@ -321,7 +382,8 @@ func (c *Causal) defrag() {
|
||||
ctx := c.backend.NewContext()
|
||||
|
||||
// For every move, 6 tensors are required per layer (2 views and a
|
||||
// copy for each of k and v).
|
||||
// copy for each of k and v). We also need to refer to the original
|
||||
// k and v cache tensors - once per layer, not per move.
|
||||
layers := 0
|
||||
for _, key := range c.keys {
|
||||
if key == nil {
|
||||
@@ -330,7 +392,7 @@ func (c *Causal) defrag() {
|
||||
layers++
|
||||
}
|
||||
|
||||
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
||||
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||
moves := 0
|
||||
|
||||
var pendingSrc, pendingDst, pendingLen int
|
||||
@@ -479,14 +541,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
|
||||
if _, ok := c.values[c.curLayer]; !ok {
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,7 +559,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||
} else {
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
@@ -528,6 +590,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
c.cellRanges[dstSeq] = seqRange
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return true
|
||||
}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
if last == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.windowSize)
|
||||
posWindowStart := max(0, pos-c.windowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
}
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
if c.shiftFn == nil {
|
||||
return ErrNotSupported
|
||||
@@ -582,6 +673,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// should return an error, which will trigger the runner to evaluate the full history and
|
||||
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// results in use after free, so we don't do it for now.
|
||||
|
||||
var offset int32
|
||||
if endIndex != math.MaxInt32 {
|
||||
offset = beginIndex - endIndex
|
||||
@@ -596,8 +693,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
} else {
|
||||
if c.cells[i].pos >= endIndex {
|
||||
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// TODO(jessegross): Need to be careful about data shared between sequences
|
||||
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
|
||||
return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
}
|
||||
|
||||
c.cells[i].pos += offset
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF32, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "SlidingWindow",
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{5, 6, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
|
||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
||||
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -290,14 +300,79 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct{}
|
||||
func TestCanResume(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
cache := NewSWACache(windowSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
func (b *testBackend) Config() ml.Config {
|
||||
panic("not implemented")
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
if !cache.CanResume(0, 0) {
|
||||
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 1) {
|
||||
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 2) {
|
||||
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
func (b *testBackend) Get(name string) ml.Tensor {
|
||||
panic("not implemented")
|
||||
type testBackend struct {
|
||||
ml.Backend
|
||||
}
|
||||
|
||||
func (b *testBackend) NewContext() ml.Context {
|
||||
@@ -308,12 +383,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
return &testContext{}
|
||||
}
|
||||
|
||||
func (b *testBackend) SystemInfo() string {
|
||||
return "not implemented"
|
||||
type testContext struct {
|
||||
ml.Context
|
||||
}
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
@@ -352,13 +425,14 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (c *testContext) Input() ml.Context { return c }
|
||||
func (c *testContext) Output() ml.Context { return c }
|
||||
func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Reserve() error { return nil }
|
||||
|
||||
func (c *testContext) MaxGraphNodes() int {
|
||||
return 10
|
||||
}
|
||||
@@ -366,6 +440,8 @@ func (c *testContext) MaxGraphNodes() int {
|
||||
func (c *testContext) Close() {}
|
||||
|
||||
type testTensor struct {
|
||||
ml.Tensor
|
||||
|
||||
dtype ml.DType
|
||||
elementSize int
|
||||
data []float32
|
||||
@@ -393,16 +469,20 @@ func (t *testTensor) DType() ml.DType {
|
||||
return t.dtype
|
||||
}
|
||||
|
||||
func (t *testTensor) Bytes() []byte {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Floats() []float32 {
|
||||
out := make([]float32, len(t.data))
|
||||
copy(out, t.data)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
for i := range out.data {
|
||||
out.data[i] = -t.data[i]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
@@ -413,66 +493,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
offset /= t.elementSize
|
||||
|
||||
@@ -495,38 +515,6 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
return view
|
||||
}
|
||||
|
||||
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
copy(t2.(*testTensor).data, t.data)
|
||||
return nil
|
||||
|
||||
@@ -27,6 +27,11 @@ type EncoderCache struct {
|
||||
// anything will be stored)
|
||||
curPos int32
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// ** cache metadata **
|
||||
|
||||
// was something stored in the cache?
|
||||
@@ -49,7 +54,7 @@ func NewEncoderCache() *EncoderCache {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
@@ -58,6 +63,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if maxSequences > 1 {
|
||||
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
}
|
||||
|
||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
}
|
||||
@@ -79,12 +88,14 @@ func (c *EncoderCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// We work with the most recent image
|
||||
if len(opts.Multimodal) > 0 {
|
||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
||||
if len(batch.Multimodal) > 0 {
|
||||
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
}
|
||||
|
||||
c.curReserve = reserve
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,8 +112,10 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
if !c.curReserve {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
}
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
@@ -130,6 +143,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("encoder cache does not support multiple sequences")
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
c.encoderCached = false
|
||||
|
||||
@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
for _, cache := range c.caches {
|
||||
cache.Init(backend, dtype, capacity)
|
||||
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
for i, cache := range c.caches {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch, reserve)
|
||||
if err != nil {
|
||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
for k := range opts.Positions {
|
||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
||||
for k := range batch.Positions {
|
||||
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
for _, cache := range c.caches {
|
||||
if !cache.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
for _, cache := range c.caches {
|
||||
|
||||
36
llama/llama.cpp/src/llama-arch.cpp
vendored
36
llama/llama.cpp/src/llama-arch.cpp
vendored
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@@ -64,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_SOLAR, "solar" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -804,6 +806,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@@ -1352,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MISTRAL3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
||||
2
llama/llama.cpp/src/llama-arch.h
vendored
2
llama/llama.cpp/src/llama-arch.h
vendored
@@ -41,6 +41,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
@@ -68,6 +69,7 @@ enum llm_arch {
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_SOLAR,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
10
llama/llama.cpp/src/llama-model.cpp
vendored
10
llama/llama.cpp/src/llama-model.cpp
vendored
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3: break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
|
||||
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3: break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
case LLM_ARCH_SOLAR:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
|
||||
4
llama/llama.cpp/src/llama-quant.cpp
vendored
4
llama/llama.cpp/src/llama-quant.cpp
vendored
@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||
|
||||
// don't quantize vision stuff
|
||||
quantize &= name.find("v.") == std::string::npos;
|
||||
quantize &= name.find("mm.") == std::string::npos;
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||
|
||||
|
||||
@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
|
||||
C.llama_kv_cache_defrag(c.c)
|
||||
}
|
||||
|
||||
func (c *Context) KvCacheCanShift() bool {
|
||||
return bool(C.llama_kv_cache_can_shift(c.c))
|
||||
}
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
|
||||
173
llama/patches/0021-add-model-quantizations.patch
Normal file
173
llama/patches/0021-add-model-quantizations.patch
Normal file
@@ -0,0 +1,173 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Patrick Devine <patrick@infrahq.com>
|
||||
Date: Fri, 14 Mar 2025 16:33:23 -0700
|
||||
Subject: [PATCH] add model quantizations
|
||||
|
||||
- gemma3
|
||||
- mistral3
|
||||
---
|
||||
src/llama-arch.cpp | 36 ++++++++++++++++++++++++++++++++++++
|
||||
src/llama-arch.h | 2 ++
|
||||
src/llama-model.cpp | 10 ++++++++++
|
||||
src/llama-quant.cpp | 4 ++++
|
||||
4 files changed, 52 insertions(+)
|
||||
|
||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||
index b6f20286..13a0a988 100644
|
||||
--- a/src/llama-arch.cpp
|
||||
+++ b/src/llama-arch.cpp
|
||||
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
+ { LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@@ -64,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_SOLAR, "solar" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
+ { LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -804,6 +806,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
+ {
|
||||
+ LLM_ARCH_GEMMA3,
|
||||
+ {
|
||||
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
+ },
|
||||
+ },
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@@ -1352,6 +1372,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
|
||||
},
|
||||
},
|
||||
+ {
|
||||
+ LLM_ARCH_MISTRAL3,
|
||||
+ {
|
||||
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
+ }
|
||||
+ },
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||
index ec742224..8476ae0a 100644
|
||||
--- a/src/llama-arch.h
|
||||
+++ b/src/llama-arch.h
|
||||
@@ -41,6 +41,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
+ LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
@@ -68,6 +69,7 @@ enum llm_arch {
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_SOLAR,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
+ LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index ab1a07d1..db4f2685 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
+ {
|
||||
+ } break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
} break;
|
||||
+ case LLM_ARCH_MISTRAL3: break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
+ {
|
||||
+ } break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
|
||||
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
|
||||
} break;
|
||||
+ case LLM_ARCH_MISTRAL3: break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
case LLM_ARCH_SOLAR:
|
||||
+ case LLM_ARCH_MISTRAL3:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
+ case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
|
||||
index 6eb1da08..ebcbafa1 100644
|
||||
--- a/src/llama-quant.cpp
|
||||
+++ b/src/llama-quant.cpp
|
||||
@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||
|
||||
+ // don't quantize vision stuff
|
||||
+ quantize &= name.find("v.") == std::string::npos;
|
||||
+ quantize &= name.find("mm.") == std::string::npos;
|
||||
+
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||
|
||||
103
llama/patches/0022-add-rdna4-support.patch
Normal file
103
llama/patches/0022-add-rdna4-support.patch
Normal file
@@ -0,0 +1,103 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Saman <saman.khatir@amd.com>
|
||||
Date: Wed, 19 Mar 2025 14:02:26 -0700
|
||||
Subject: [PATCH] add rdna4 support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/common.cuh | 6 ++++--
|
||||
ggml/src/ggml-cuda/mmq.cu | 2 +-
|
||||
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
|
||||
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
|
||||
5 files changed, 13 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||
index adf0d3ec..b24593fc 100644
|
||||
--- a/ggml/src/ggml-cuda/common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
-#elif defined(RDNA3)
|
||||
+#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||
index 10f2ebb1..933d945c 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
|
||||
index 0451c65f..66ce2bc9 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cuh
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cuh
|
||||
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
|
||||
index 4fb466ca..23ae7abc 100644
|
||||
--- a/ggml/src/ggml-cuda/mmvq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmvq.cu
|
||||
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index 81964611..a62544b5 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
+#define RDNA4
|
||||
+#endif
|
||||
+
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
75
llama/patches/0022-metal-add-op_neg.patch
Normal file
75
llama/patches/0022-metal-add-op_neg.patch
Normal file
@@ -0,0 +1,75 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <git@mxy.ng>
|
||||
Date: Wed, 2 Apr 2025 15:26:15 -0700
|
||||
Subject: [PATCH] metal: add op_neg
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal.m | 15 +++++++++++++++
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 7 +++++++
|
||||
2 files changed, 22 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index e4c093f9..d8422f1b 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
+ GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
+ case GGML_UNARY_OP_NEG:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
+ case GGML_UNARY_OP_NEG:
|
||||
+ {
|
||||
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
||||
+
|
||||
+ [encoder setComputePipelineState:pipeline];
|
||||
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
+
|
||||
+ const int64_t n = ggml_nelements(dst);
|
||||
+
|
||||
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
+ } break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index f38909d0..bb0ff668 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -945,6 +945,13 @@ kernel void kernel_cos(
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
+kernel void kernel_neg(
|
||||
+ device const float * src0,
|
||||
+ device float * dst,
|
||||
+ uint tpig[[thread_position_in_grid]]) {
|
||||
+ dst[tpig] = -src0[tpig];
|
||||
+}
|
||||
+
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
)
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
var layerCount int
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
|
||||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||
// The GPUs provided must all be the same Library
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||
// Graph size for a partial offload, applies to all GPUs
|
||||
var graphPartialOffload uint64
|
||||
|
||||
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
}
|
||||
|
||||
var kvTotal uint64
|
||||
for _, kvLayer := range kv {
|
||||
kvTotal += kvLayer
|
||||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQA() * kv / 6
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
@@ -217,9 +223,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
// Some models have inconsistent layer sizes
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
layerSize = blk.Size()
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
layerSize += kv[i]
|
||||
memoryWeights += blk.Size()
|
||||
}
|
||||
memoryWeights += layerSize
|
||||
|
||||
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
||||
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
||||
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
layersRequested: opts.NumGPU,
|
||||
layersModel: int(f.KV().BlockCount()) + 1,
|
||||
availableList: availableList,
|
||||
kv: kv,
|
||||
kv: kvTotal,
|
||||
allocationsList: allocationsList,
|
||||
memoryWeights: memoryWeights,
|
||||
memoryLayerOutput: memoryLayerOutput,
|
||||
@@ -374,9 +380,9 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
||||
slog.Group(
|
||||
"weights",
|
||||
// memory of the weights
|
||||
"total", format.HumanBytes2(m.memoryWeights),
|
||||
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
|
||||
// memory of repeating layers
|
||||
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
|
||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of non-repeating layers
|
||||
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
|
||||
),
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
projectors := []string{}
|
||||
opts := api.DefaultOptions()
|
||||
t.Run("cpu", func(t *testing.T) {
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, 0, estimate.Layers)
|
||||
assert.Equal(t, uint64(0), estimate.Graph)
|
||||
})
|
||||
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
var layerSums uint64
|
||||
|
||||
162
llm/server.go
162
llm/server.go
@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
gpus = discover.GetCPUInfo()
|
||||
}
|
||||
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||
switch {
|
||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||
@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||
}
|
||||
|
||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||
slog.Info("starting llama server", "cmd", s.cmd)
|
||||
if envconfig.Debug() {
|
||||
filteredEnv := []string{}
|
||||
for _, ev := range s.cmd.Env {
|
||||
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
func (s ServerStatus) String() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "llm server ready"
|
||||
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
|
||||
}
|
||||
}
|
||||
|
||||
type ServerStatusResp struct {
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
Progress float32 `json:"progress"`
|
||||
type ServerStatusResponse struct {
|
||||
Status ServerStatus `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
}
|
||||
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
|
||||
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||
}
|
||||
|
||||
var status ServerStatusResp
|
||||
if err := json.Unmarshal(body, &status); err != nil {
|
||||
var ssr ServerStatusResponse
|
||||
if err := json.Unmarshal(body, &ssr); err != nil {
|
||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case "ok":
|
||||
return ServerStatusReady, nil
|
||||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvailable, nil
|
||||
case "loading model":
|
||||
s.loadProgress = status.Progress
|
||||
return ServerStatusLoadingModel, nil
|
||||
switch ssr.Status {
|
||||
case ServerStatusLoadingModel:
|
||||
s.loadProgress = ssr.Progress
|
||||
return ssr.Status, nil
|
||||
case ServerStatusReady, ServerStatusNoSlotsAvailable:
|
||||
return ssr.Status, nil
|
||||
default:
|
||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
status, _ := s.getServerStatus(ctx)
|
||||
if lastStatus != status && status != ServerStatusReady {
|
||||
// Only log on status changes
|
||||
slog.Info("waiting for server to become available", "status", status.ToString())
|
||||
slog.Info("waiting for server to become available", "status", status)
|
||||
}
|
||||
switch status {
|
||||
case ServerStatusReady:
|
||||
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
|
||||
stallTimer = time.Now().Add(stallDuration)
|
||||
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
|
||||
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
|
||||
slog.Debug("model load completed, waiting for server to become available", "status", status)
|
||||
stallTimer = time.Now().Add(stallDuration)
|
||||
fullyLoaded = true
|
||||
}
|
||||
@@ -671,63 +666,49 @@ type ImageData struct {
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type completion struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
StoppedLimit bool `json:"stopped_limit"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
type DoneReason int
|
||||
|
||||
const (
|
||||
// DoneReasonStop indicates the completion stopped naturally
|
||||
DoneReasonStop DoneReason = iota
|
||||
// DoneReasonLength indicates the completion stopped due to length limits
|
||||
DoneReasonLength
|
||||
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
||||
DoneReasonConnectionClosed
|
||||
)
|
||||
|
||||
func (d DoneReason) String() string {
|
||||
switch d {
|
||||
case DoneReasonLength:
|
||||
return "length"
|
||||
case DoneReasonStop:
|
||||
return "stop"
|
||||
default:
|
||||
return "" // closed
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
DoneReason string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
Content string `json:"content"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
request := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": req.Options.NumPredict,
|
||||
"n_keep": req.Options.NumKeep,
|
||||
"main_gpu": req.Options.MainGPU,
|
||||
"temperature": req.Options.Temperature,
|
||||
"top_k": req.Options.TopK,
|
||||
"top_p": req.Options.TopP,
|
||||
"min_p": req.Options.MinP,
|
||||
"typical_p": req.Options.TypicalP,
|
||||
"repeat_last_n": req.Options.RepeatLastN,
|
||||
"repeat_penalty": req.Options.RepeatPenalty,
|
||||
"presence_penalty": req.Options.PresencePenalty,
|
||||
"frequency_penalty": req.Options.FrequencyPenalty,
|
||||
"mirostat": req.Options.Mirostat,
|
||||
"mirostat_tau": req.Options.MirostatTau,
|
||||
"mirostat_eta": req.Options.MirostatEta,
|
||||
"seed": req.Options.Seed,
|
||||
"stop": req.Options.Stop,
|
||||
"image_data": req.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
@@ -735,7 +716,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
// these as "not set".
|
||||
break
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
req.Grammar = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
@@ -746,10 +727,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
req.Grammar = string(g)
|
||||
}
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
@@ -770,7 +756,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != ServerStatusReady {
|
||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
return fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
@@ -778,7 +764,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
if err := enc.Encode(req); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
@@ -823,13 +809,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
continue
|
||||
}
|
||||
|
||||
// slog.Debug("got line", "line", string(line))
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
evt = line
|
||||
}
|
||||
|
||||
var c completion
|
||||
var c CompletionResponse
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
}
|
||||
@@ -853,20 +838,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
doneReason := "stop"
|
||||
if c.StoppedLimit {
|
||||
doneReason = "length"
|
||||
}
|
||||
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: doneReason,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
if c.Done {
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -914,7 +887,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||
@@ -1059,12 +1032,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
|
||||
@@ -2,28 +2,19 @@ package ml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
Architecture() string
|
||||
String(string, ...string) string
|
||||
Uint(string, ...uint32) uint32
|
||||
Float(string, ...float32) float32
|
||||
Bool(string, ...bool) bool
|
||||
|
||||
Strings(string, ...[]string) []string
|
||||
Uints(string, ...[]uint32) []uint32
|
||||
Floats(string, ...[]float32) []float32
|
||||
}
|
||||
|
||||
type Backend interface {
|
||||
Config() Config
|
||||
Config() fs.Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
NewContextSize(size int) Context
|
||||
@@ -60,6 +51,10 @@ type CacheConfig struct {
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// Progress is a callback function that allows reporting percentage completion
|
||||
// of model loading
|
||||
Progress func(float32)
|
||||
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
NumThreads int
|
||||
|
||||
@@ -76,9 +71,9 @@ type BackendParams struct {
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
||||
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
|
||||
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
|
||||
if _, ok := backends[name]; ok {
|
||||
panic("backend: backend already registered")
|
||||
}
|
||||
@@ -86,9 +81,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
|
||||
backends[name] = f
|
||||
}
|
||||
|
||||
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
||||
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
|
||||
if backend, ok := backends["ggml"]; ok {
|
||||
return backend(f, params)
|
||||
return backend(ctx, f, params)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported backend")
|
||||
@@ -102,15 +97,20 @@ type Context interface {
|
||||
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
// worst case graph to ensure all resources are available for
|
||||
// for future inference.
|
||||
Reserve() error
|
||||
|
||||
MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating input tensors
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Output returns a context appropriate for creating output tensors
|
||||
Output() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
}
|
||||
@@ -125,6 +125,7 @@ type Tensor interface {
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Mul(ctx Context, t2 Tensor) Tensor
|
||||
Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
@@ -139,7 +140,10 @@ type Tensor interface {
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
Sin(ctx Context) Tensor
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
@@ -154,9 +158,13 @@ type Tensor interface {
|
||||
Unpad(ctx Context, shape ...int) Tensor
|
||||
|
||||
Stack(ctx Context, dim int, s ...Tensor) Tensor
|
||||
|
||||
// Repeat repeats the tensor n times along dimension dim
|
||||
Repeat(ctx Context, dim, n int) Tensor
|
||||
Concat(ctx Context, t2 Tensor, dim int) Tensor
|
||||
Rows(ctx Context, t2 Tensor) Tensor
|
||||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
Duplicate(ctx Context) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
@@ -221,7 +229,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
})
|
||||
case DTypeF16, DTypeQ80, DTypeQ40:
|
||||
f32 := ctx.Empty(DTypeF32, t.Shape()...)
|
||||
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
|
||||
f32 = t.Copy(ctx, f32)
|
||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
|
||||
@@ -9,20 +9,24 @@ package ggml
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unicode"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -39,16 +43,17 @@ func devices() []*C.struct_ggml_backend_device {
|
||||
}
|
||||
|
||||
type Backend struct {
|
||||
meta *fs.GGML
|
||||
sched *C.struct_ggml_backend_sched
|
||||
meta *fsggml.GGML
|
||||
|
||||
sched *C.struct_ggml_backend_sched
|
||||
schedBackends []*C.struct_ggml_backend
|
||||
schedBufts []*C.struct_ggml_backend_buffer_type
|
||||
|
||||
tensors map[string]*C.struct_ggml_tensor
|
||||
|
||||
// input is the backend used for inputs
|
||||
input *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// output is the backend used for outputs
|
||||
output *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// layers is the backend used for repeating layers
|
||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||
|
||||
@@ -58,8 +63,8 @@ type Backend struct {
|
||||
maxGraphNodes int
|
||||
}
|
||||
|
||||
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
meta, n, err := fs.Decode(r, -1)
|
||||
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
meta, n, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -183,7 +188,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
maxTensors += blocks * 2
|
||||
|
||||
type tensor struct {
|
||||
source *fs.Tensor
|
||||
source *fsggml.Tensor
|
||||
target string
|
||||
}
|
||||
|
||||
@@ -281,6 +286,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
}
|
||||
|
||||
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
|
||||
bbs[c] = b
|
||||
}
|
||||
@@ -297,12 +306,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
|
||||
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
|
||||
var g errgroup.Group
|
||||
var doneBytes atomic.Uint64
|
||||
totalBytes := uint64(n) - meta.Tensors().Offset
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
for _, t := range meta.Tensors().Items() {
|
||||
for _, target := range targets[t.Name] {
|
||||
g.Go(func() error {
|
||||
g.Go(func() error {
|
||||
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
|
||||
for i := range tts {
|
||||
target := targets[t.Name][i]
|
||||
if target == "" {
|
||||
target = t.Name
|
||||
}
|
||||
@@ -312,23 +325,51 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
||||
}
|
||||
|
||||
bts := make([]byte, t.Size())
|
||||
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
|
||||
tts[i] = tt
|
||||
}
|
||||
|
||||
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
|
||||
// seeking around within an FD shared between all goroutines.
|
||||
file, err := os.Open(r.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
||||
bts := make([]byte, 128*format.KibiByte)
|
||||
|
||||
var s uint64
|
||||
for s < t.Size() {
|
||||
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if n != len(bts) {
|
||||
return errors.New("short read")
|
||||
for _, tt := range tts {
|
||||
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
|
||||
}
|
||||
|
||||
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
s += uint64(n)
|
||||
|
||||
if params.Progress != nil {
|
||||
done := doneBytes.Add(uint64(n))
|
||||
params.Progress(float32(done) / float32(totalBytes))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if g.Wait() != nil {
|
||||
// start a goroutine to cancel the errgroup if the parent context is done
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
g.Go(func() error {
|
||||
return ctx.Err()
|
||||
})
|
||||
}()
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -353,8 +394,6 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
schedBackends = append(schedBackends, b)
|
||||
schedBufts = append(schedBufts, bt)
|
||||
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
if C.ggml_backend_is_cpu(b) {
|
||||
// set number of threads for cpu backend
|
||||
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
|
||||
@@ -371,10 +410,11 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
|
||||
C.int(len(schedBackends)),
|
||||
C.size_t(maxGraphNodes),
|
||||
true,
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
input: deviceBufferTypes[input.d],
|
||||
output: deviceBufferTypes[output.d],
|
||||
schedBackends: schedBackends,
|
||||
schedBufts: schedBufts,
|
||||
input: deviceBufferTypes[input.d],
|
||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||
for i, layer := range layers {
|
||||
@@ -390,7 +430,7 @@ func init() {
|
||||
ml.RegisterBackend("ggml", New)
|
||||
}
|
||||
|
||||
func (b *Backend) Config() ml.Config {
|
||||
func (b *Backend) Config() fs.Config {
|
||||
return b.meta.KV()
|
||||
}
|
||||
|
||||
@@ -455,19 +495,6 @@ func (c Context) Input() ml.Context {
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Output() ml.Context {
|
||||
if c.b.output != nil {
|
||||
return &Context{
|
||||
b: c.b,
|
||||
ctx: c.ctx,
|
||||
buft: c.b.output,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Layer(i int) ml.Context {
|
||||
if buft, ok := c.b.layers[i]; ok {
|
||||
return &Context{
|
||||
@@ -512,6 +539,24 @@ func (c Context) Compute(tensors ...ml.Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) Reserve() error {
|
||||
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
return errors.New("failed to reserve graph")
|
||||
}
|
||||
|
||||
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
|
||||
for i := range c.b.schedBackends {
|
||||
size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
|
||||
"size", format.HumanBytes2(uint64(size)))
|
||||
}
|
||||
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Context) MaxGraphNodes() int {
|
||||
return c.maxGraphNodes
|
||||
}
|
||||
@@ -529,9 +574,9 @@ func pad(length, pad C.size_t) C.size_t {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
if c.buft == nil {
|
||||
panic("set Input, Output, or Layer before creating tensors")
|
||||
panic("set Input or Layer before creating tensors")
|
||||
}
|
||||
|
||||
var cdtype uint32
|
||||
@@ -552,7 +597,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
|
||||
} else if len(shape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@@ -566,16 +611,29 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
|
||||
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
|
||||
}
|
||||
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
return &Tensor{b: c.b, t: t}
|
||||
return &Tensor{b: c.b, t: t}, nil
|
||||
}
|
||||
|
||||
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return c.newTensor(dtype, shape)
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t := c.newTensor(dtype, shape)
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
C.ggml_set_zero(t.(*Tensor).t)
|
||||
return t
|
||||
}
|
||||
@@ -603,7 +661,11 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeF32, shape)
|
||||
t, err := c.newTensor(ml.DTypeF32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
@@ -616,7 +678,11 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeI32, shape)
|
||||
t, err := c.newTensor(ml.DTypeI32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
@@ -700,6 +766,13 @@ func (t *Tensor) DType() ml.DType {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
@@ -707,6 +780,27 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
|
||||
if dim < 0 || dim >= C.GGML_MAX_DIMS {
|
||||
panic("invalid dimension")
|
||||
}
|
||||
|
||||
shape := make([]C.int64_t, C.GGML_MAX_DIMS)
|
||||
for i := range C.GGML_MAX_DIMS {
|
||||
if i == dim {
|
||||
shape[i] = C.int64_t(t.Dim(i) * n)
|
||||
} else {
|
||||
shape[i] = C.int64_t(t.Dim(i))
|
||||
}
|
||||
}
|
||||
|
||||
tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape))
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
if len(s) > 0 {
|
||||
return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
|
||||
@@ -843,6 +937,20 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sin(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cos(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
@@ -931,6 +1039,13 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
@@ -999,3 +1114,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_dup(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#elif defined(RDNA3)
|
||||
#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
|
||||
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#define RDNA4
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
|
||||
@@ -3083,6 +3083,13 @@ kernel void kernel_cos(
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_neg(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
15
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
15
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_NEG:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
||||
@@ -945,6 +945,13 @@ kernel void kernel_cos(
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_neg(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
||||
5
ml/backend/ggml/ggml/src/ollama-debug.c
vendored
5
ml/backend/ggml/ggml/src/ollama-debug.c
vendored
@@ -1,4 +1,5 @@
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
|
||||
#include "ollama-debug.h"
|
||||
|
||||
@@ -24,7 +25,7 @@ static void print_tensor(const void *tensor, void (*cb)(const void *, int),
|
||||
fprintf(stderr, "[");
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
if (i >= nitems && i < dims[0] - nitems) {
|
||||
fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
|
||||
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
|
||||
int skip = dims[0] - 2 * nitems;
|
||||
if (ndims > 1) {
|
||||
stride += mul(dims + 1, ndims - 1) * skip;
|
||||
@@ -67,7 +68,7 @@ static void print_tensor_i32(const void *tensor, int i) {
|
||||
}
|
||||
|
||||
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
|
||||
fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
|
||||
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
|
||||
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
|
||||
tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package input
|
||||
|
||||
import "github.com/ollama/ollama/ml"
|
||||
|
||||
// Input represents one token in the input stream
|
||||
type Input struct {
|
||||
// Token is a single element of text.
|
||||
@@ -15,6 +17,12 @@ type Input struct {
|
||||
// stored in Multimodal, used for caching and comparing
|
||||
// equality.
|
||||
MultimodalHash uint64
|
||||
|
||||
// SameBatch forces the following number of tokens to be processed
|
||||
// in a single batch, breaking and extending batches as needed.
|
||||
// Useful for things like images that must be processed in one
|
||||
// shot.
|
||||
SameBatch int
|
||||
}
|
||||
|
||||
// MultimodalIndex is a multimodal element (such as an image)
|
||||
@@ -27,11 +35,24 @@ type MultimodalIndex struct {
|
||||
Multimodal any
|
||||
}
|
||||
|
||||
// Options contains the inputs for a model forward pass
|
||||
type Options struct {
|
||||
Inputs []int32
|
||||
// Batch contains the inputs for a model forward pass
|
||||
type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
Positions []int32
|
||||
Sequences []int
|
||||
Outputs []int32
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
Positions []int32
|
||||
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs []int32
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
@@ -15,16 +16,19 @@ import (
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
||||
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
@@ -58,7 +62,7 @@ type MultimodalProcessor interface {
|
||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
|
||||
PostTokenize([]input.Input) ([]input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
@@ -80,10 +84,10 @@ func (m *Base) Config() config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
var models = make(map[string]func(ml.Config) (Model, error))
|
||||
var models = make(map[string]func(fs.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
func Register(name string, f func(ml.Config) (Model, error)) {
|
||||
func Register(name string, f func(fs.Config) (Model, error)) {
|
||||
if _, ok := models[name]; ok {
|
||||
panic("model: model already registered")
|
||||
}
|
||||
@@ -92,14 +96,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
|
||||
}
|
||||
|
||||
// New initializes a new model instance with the provided configuration based on the metadata in the model file
|
||||
func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
|
||||
r, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
b, err := ml.NewBackend(r, params)
|
||||
b, err := ml.NewBackend(ctx, r, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -128,14 +132,14 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
meta, _, err := fs.Decode(r, -1)
|
||||
meta, _, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getTextProcessor(meta.KV())
|
||||
}
|
||||
|
||||
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
|
||||
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
|
||||
arch := kv.Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
@@ -278,24 +282,30 @@ func canNil(t reflect.Type) bool {
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
||||
if len(opts.Positions) != len(opts.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(opts.Positions) < 1 {
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
var err error
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, opts)
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -139,7 +140,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetTextProcessor(t *testing.T) {
|
||||
tp, err := getTextProcessor(fs.KV{})
|
||||
tp, err := getTextProcessor(fsggml.KV{})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
||||
@@ -148,10 +149,10 @@ func TestGetTextProcessor(t *testing.T) {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
|
||||
models["dummy"] = func(ml.Config) (Model, error) {
|
||||
models["dummy"] = func(fs.Config) (Model, error) {
|
||||
return notTextProcessorModel{}, nil
|
||||
}
|
||||
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
|
||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
||||
@@ -163,7 +164,7 @@ func TestGetTextProcessor(t *testing.T) {
|
||||
|
||||
type notTextProcessorModel struct{}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package gemma2
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -35,10 +36,9 @@ const (
|
||||
gemma27BLayerCount = 46
|
||||
)
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -168,23 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
|
||||
if len(m.Layers) == gemma27BLayerCount {
|
||||
@@ -211,8 +206,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
// final logit softcap
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
|
||||
return hiddenState.Rows(ctx, outputs), nil
|
||||
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -2,11 +2,11 @@ package gemma3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -53,10 +53,9 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -84,6 +83,10 @@ func New(c ml.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -108,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return visionOutputs, nil
|
||||
}
|
||||
|
||||
type imageToken struct {
|
||||
embedding ml.Tensor
|
||||
index int
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
fnvHash := fnv.New64a()
|
||||
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
imageInputs := []input.Input{
|
||||
{Token: 108}, // "\n\n"
|
||||
{Token: 255999}, // "<start_of_image>""
|
||||
}
|
||||
result = append(result, imageInputs...)
|
||||
|
||||
// add image embeddings
|
||||
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
||||
|
||||
for i := range inputMultimodal.Dim(1) {
|
||||
fnvHash.Reset()
|
||||
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
|
||||
fnvHash.Write([]byte{byte(i)})
|
||||
result = append(result,
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
imageToken := imageToken{embedding: inputMultimodal, index: i}
|
||||
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
|
||||
}
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
@@ -149,23 +139,18 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -3,6 +3,7 @@ package gemma3
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -10,12 +11,11 @@ import (
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
type TextConfig struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase, ropeGlobalBase float32
|
||||
finalLogitSoftcap float32
|
||||
largeModelScaling bool
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ type TextModel struct {
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextOptions
|
||||
*TextConfig
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -41,12 +41,11 @@ const (
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
func newTextModel(c ml.Config) *TextModel {
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numBlocks := int(c.Uint("block_count"))
|
||||
|
||||
m := TextModel{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -56,17 +55,16 @@ func newTextModel(c ml.Config) *TextModel {
|
||||
},
|
||||
),
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextOptions: &TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -86,7 +84,7 @@ type TextSelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
ropeType := uint32(2)
|
||||
|
||||
@@ -122,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.TextOptions.ropeLocalBase
|
||||
ropeBase := m.TextConfig.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = m.TextOptions.ropeGlobalBase
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
}
|
||||
|
||||
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -136,7 +134,7 @@ type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
@@ -150,7 +148,7 @@ type TextLayer struct {
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
@@ -173,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
|
||||
var embedding ml.Tensor
|
||||
var src, dst, length int
|
||||
var except []int
|
||||
|
||||
for _, image := range multimodal {
|
||||
imageToken := image.Multimodal.(imageToken)
|
||||
imageSrc := imageToken.index
|
||||
imageDst := image.Index
|
||||
|
||||
if embedding == nil {
|
||||
embedding = imageToken.embedding
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length = 1
|
||||
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length++
|
||||
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
|
||||
length++
|
||||
} else {
|
||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||
|
||||
embedding = imageToken.embedding
|
||||
src = imageSrc
|
||||
dst = imageDst
|
||||
length = 1
|
||||
}
|
||||
|
||||
except = append(except, imageDst)
|
||||
}
|
||||
|
||||
if embedding != nil {
|
||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
||||
}
|
||||
|
||||
return except
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
|
||||
// set image embeddings
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
// gemma alternates between the sliding window (local) and causal (global)
|
||||
@@ -241,14 +206,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
// final logit softcap
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package gemma3
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
@@ -111,7 +112,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c ml.Config) *VisionModel {
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
|
||||
@@ -3,7 +3,7 @@ package gemma3
|
||||
import (
|
||||
"image"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ type ImageProcessor struct {
|
||||
imageSize, patchSize, numChannels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c ml.Config) ImageProcessor {
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -30,7 +31,7 @@ type Model struct {
|
||||
*Options
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||
}
|
||||
@@ -139,23 +140,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
56
model/models/mistral3/imageproc.go
Normal file
56
model/models/mistral3/imageproc.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package mistral3
|
||||
|
||||
import (
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
patchSize int
|
||||
numChannels int
|
||||
longestEdge int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 1540)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessImage prepares an image for the vision model by:
|
||||
// 1. Compositing transparent images
|
||||
// 2. Resizing to fit model constraints while preserving aspect ratio
|
||||
// 3. Normalizing pixel values
|
||||
// Returns normalized image data and the final size in pixels
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
size := img.Bounds().Size()
|
||||
ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge))
|
||||
if ratio > 1.0 {
|
||||
size = image.Point{
|
||||
int(math.Floor(float64(size.X) / ratio)),
|
||||
int(math.Floor(float64(size.Y) / ratio)),
|
||||
}
|
||||
}
|
||||
|
||||
patchesX := (size.X-1)/p.patchSize + 1
|
||||
patchesY := (size.Y-1)/p.patchSize + 1
|
||||
size = image.Point{
|
||||
patchesX * p.patchSize,
|
||||
patchesY * p.patchSize,
|
||||
}
|
||||
|
||||
img = imageproc.Resize(img, size, imageproc.ResizeBilinear)
|
||||
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
|
||||
return data, size, nil
|
||||
}
|
||||
189
model/models/mistral3/model.go
Normal file
189
model/models/mistral3/model.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package mistral3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
// Implement MultimodalProcessor interface
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
textModel, err := NewTextModel(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
TextModel: textModel,
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
MultiModalProjector: newMultiModalProjector(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type PatchMerger struct {
|
||||
MergingLayer *nn.Linear `gguf:"merging_layer"`
|
||||
}
|
||||
|
||||
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor {
|
||||
d := visionOutputs.Dim(0)
|
||||
imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d)
|
||||
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d)
|
||||
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
|
||||
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
|
||||
return pm.MergingLayer.Forward(ctx, reshaped)
|
||||
}
|
||||
|
||||
type MultiModalProjector struct {
|
||||
Norm *nn.RMSNorm `gguf:"norm"`
|
||||
Linear1 *nn.Linear `gguf:"linear_1"`
|
||||
Linear2 *nn.Linear `gguf:"linear_2"`
|
||||
PatchMerger *PatchMerger `gguf:"patch_merger"`
|
||||
|
||||
spatialMergeSize int
|
||||
eps float32
|
||||
patchSize int
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) {
|
||||
visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps)
|
||||
patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize}
|
||||
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize)
|
||||
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
|
||||
visionOutputs = visionOutputs.GELU(ctx)
|
||||
return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize}
|
||||
}
|
||||
|
||||
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
|
||||
return &MultiModalProjector{
|
||||
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
|
||||
eps: c.Float("text_config.rms_norm_eps", 1e-5),
|
||||
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, size, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
|
||||
|
||||
// split into patches to be sent to the text transformer
|
||||
parent := imageFeatures{tensor: features}
|
||||
rows := make([]*imageRow, size.Y)
|
||||
for i := range rows {
|
||||
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
type imageFeatures struct {
|
||||
tensor ml.Tensor
|
||||
|
||||
dataOnce sync.Once
|
||||
data []float32
|
||||
}
|
||||
|
||||
type imageRow struct {
|
||||
parent *imageFeatures
|
||||
s int
|
||||
shape []int
|
||||
}
|
||||
|
||||
func (r *imageRow) data() []float32 {
|
||||
n := 1
|
||||
for _, s := range r.shape {
|
||||
n *= s
|
||||
}
|
||||
|
||||
return r.parent.data[r.s*n : (r.s+1)*n]
|
||||
}
|
||||
|
||||
// PostTokenize arranges Mistral 3's inputs for the forward pass
|
||||
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
|
||||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||
// that can be processed together.
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
inputMultimodal := inp.Multimodal.([]*imageRow)
|
||||
for i, row := range inputMultimodal {
|
||||
// [IMG]
|
||||
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
|
||||
if i == len(inputMultimodal)-1 {
|
||||
// [IMG_END]
|
||||
result = append(result, input.Input{Token: 13})
|
||||
} else {
|
||||
// [IMG_BREAK]
|
||||
result = append(result, input.Input{Token: 12})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("mistral3", New)
|
||||
}
|
||||
177
model/models/mistral3/model_text.go
Normal file
177
model/models/mistral3/model_text.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package mistral3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenSize, numHeads, numKVHeads, headDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextOptions
|
||||
}
|
||||
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
ropeType := uint32(0)
|
||||
headDim := opts.headDim
|
||||
if headDim == 0 {
|
||||
headDim = opts.hiddenSize / opts.numHeads
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *SelfAttention
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||
|
||||
// image embeddings
|
||||
for _, image := range batch.Multimodal {
|
||||
row := image.Multimodal.(*imageRow)
|
||||
row.parent.dataOnce.Do(func() {
|
||||
// use a new, throwaway context so the image tensor is not added to the graph
|
||||
temp := m.Backend().NewContext()
|
||||
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
|
||||
row.parent.data = row.parent.tensor.Floats()
|
||||
temp.Close()
|
||||
})
|
||||
|
||||
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
func NewTextModel(c fs.Config) (*TextModel, error) {
|
||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||
}
|
||||
|
||||
textModel := &TextModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextOptions: &TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
},
|
||||
}
|
||||
|
||||
return textModel, nil
|
||||
}
|
||||
186
model/models/mistral3/model_vision.go
Normal file
186
model/models/mistral3/model_vision.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package mistral3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
|
||||
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
|
||||
return x2.Neg(ctx).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
intermediateSize int
|
||||
imageSize int
|
||||
patchSize int
|
||||
numChannels int
|
||||
eps float32
|
||||
ropeBase float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
|
||||
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
|
||||
maxPatchesPerSide := m.imageSize / m.patchSize
|
||||
frequencies := m.headDim / 2
|
||||
frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
|
||||
frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
|
||||
for i := range frequencies {
|
||||
for j := range maxPatchesPerSide {
|
||||
frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
|
||||
if i%2 == 0 {
|
||||
frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
|
||||
} else {
|
||||
frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
h = h.Repeat(ctx, 1, maxPatchesPerSide)
|
||||
h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
w = w.Repeat(ctx, 2, maxPatchesPerSide)
|
||||
|
||||
inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
|
||||
inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
|
||||
return inverseFrequencies.Rows(ctx, positionIDs)
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatchesW := pixelValues.Dim(0) / m.patchSize
|
||||
numPatchesH := pixelValues.Dim(1) / m.patchSize
|
||||
numPatches := numPatchesW * numPatchesH
|
||||
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
|
||||
|
||||
// Prepare position IDs for 2D rope
|
||||
positions := make([]int32, numPatches)
|
||||
for h := range numPatchesH {
|
||||
for w := range numPatchesW {
|
||||
idx := h*numPatchesW + w
|
||||
positions[idx] = int32(h*m.imageSize/m.patchSize + w)
|
||||
}
|
||||
}
|
||||
|
||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
|
||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
|
||||
numHeads: int(c.Uint("vision.attention.head_count", 16)),
|
||||
headDim: int(c.Uint("vision.attention.key_length", 64)),
|
||||
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
|
||||
imageSize: int(c.Uint("vision.image_size", 1540)),
|
||||
patchSize: int(c.Uint("vision.patch_size", 14)),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
|
||||
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -32,7 +33,7 @@ const (
|
||||
selfAttentionLayer
|
||||
)
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// Verify unified config
|
||||
if c.Uint("vision.block_count") == 0 {
|
||||
return nil, fmt.Errorf("non-unified vision model not supported")
|
||||
@@ -63,6 +64,10 @@ func New(c ml.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,17 +107,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var images []input.Input
|
||||
fnvHash := fnv.New64a()
|
||||
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal == nil {
|
||||
if len(images) > 0 {
|
||||
inputs[i].Multimodal = images[0].Multimodal
|
||||
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
|
||||
inputs[i].MultimodalHash = images[0].MultimodalHash
|
||||
for j := 1; j < len(images); j++ {
|
||||
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
|
||||
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
|
||||
fnvHash.Reset()
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
||||
@@ -131,29 +136,27 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
var crossAttentionStates ml.Tensor
|
||||
if len(opts.Multimodal) > 0 {
|
||||
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||
if len(images) > 0 {
|
||||
crossAttentionStates = images[len(images)-1]
|
||||
}
|
||||
}
|
||||
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -220,7 +221,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
func newTextModel(c ml.Config) *TextModel {
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
var decoderLayers []TextDecoderLayer
|
||||
for i := range c.Uint("block_count") {
|
||||
var textDecoderLayer TextDecoderLayer
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
@@ -185,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
|
||||
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
|
||||
hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1)
|
||||
|
||||
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
|
||||
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
@@ -213,7 +214,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
return hiddenState.Concat(ctx, hiddenStates, 0)
|
||||
}
|
||||
|
||||
func newVisionModel(c ml.Config) *VisionModel {
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
|
||||
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, numChannels, maxNumTiles int
|
||||
}
|
||||
|
||||
func newImageProcessor(c ml.Config) ImageProcessor {
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
numChannels: int(c.Uint("vision.num_channels")),
|
||||
|
||||
@@ -4,5 +4,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
package pixtral
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
|
||||
return image.Point{
|
||||
(imageSize.X-1)/patchSize.X + 1,
|
||||
(imageSize.Y-1)/patchSize.Y + 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
|
||||
b := img.Bounds()
|
||||
le := float64(longestEdge)
|
||||
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
|
||||
|
||||
newSize := img.Bounds().Max
|
||||
|
||||
if ratio > 1.0 {
|
||||
newSize = image.Point{
|
||||
int(math.Ceil(float64(b.Max.X) / ratio)),
|
||||
int(math.Ceil(float64(b.Max.Y) / ratio)),
|
||||
}
|
||||
}
|
||||
|
||||
tokens := getNumImageTokens(newSize, patchSize)
|
||||
return image.Point{
|
||||
tokens.X * patchSize.X,
|
||||
tokens.Y * patchSize.Y,
|
||||
}
|
||||
}
|
||||
|
||||
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
|
||||
if format == "png" {
|
||||
img = imageproc.Composite(img)
|
||||
}
|
||||
|
||||
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
|
||||
|
||||
// todo should be ResizeBicubic, but it doesn't exist
|
||||
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
|
||||
}
|
||||
|
||||
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
|
||||
img, format, err := image.Decode(imageData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
longestEdge := 1024
|
||||
patchSize := image.Point{16, 16}
|
||||
|
||||
img = resizeImage(img, format, longestEdge, patchSize)
|
||||
|
||||
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
|
||||
|
||||
opts := map[string]any{}
|
||||
return data, opts, nil
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
package pixtral
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"image"
|
||||
"image/png"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestGetNumImageTokens(t *testing.T) {
|
||||
type numImageTokensCase struct {
|
||||
ImageSize image.Point
|
||||
PatchSize image.Point
|
||||
Expected image.Point
|
||||
}
|
||||
|
||||
cases := []numImageTokensCase{
|
||||
{
|
||||
ImageSize: image.Point{1024, 764},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{64, 48},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{800, 600},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{50, 38},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{640, 480},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{40, 30},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{320, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{20, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{1320, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{83, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{2000, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{125, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{10000, 200},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{625, 13},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{1131, 577},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{71, 37},
|
||||
},
|
||||
{
|
||||
ImageSize: image.Point{16, 16},
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1, 1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
|
||||
|
||||
if diff := cmp.Diff(actual, c.Expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResizeOutputImageSize(t *testing.T) {
|
||||
type resizeCase struct {
|
||||
Image image.Image
|
||||
LongestEdge int
|
||||
PatchSize image.Point
|
||||
Expected image.Point
|
||||
}
|
||||
|
||||
cases := []resizeCase{
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 768},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 624},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{304, 208},
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.Point{1024, 288},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
|
||||
|
||||
if diff := cmp.Diff(actual, c.Expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResize(t *testing.T) {
|
||||
type resizeCase struct {
|
||||
Image image.Image
|
||||
LongestEdge int
|
||||
PatchSize image.Point
|
||||
Expected image.Image
|
||||
}
|
||||
|
||||
cases := []resizeCase{
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
|
||||
},
|
||||
{
|
||||
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
|
||||
LongestEdge: 1024,
|
||||
PatchSize: image.Point{16, 16},
|
||||
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
|
||||
|
||||
if actual.Bounds() != c.Expected.Bounds() {
|
||||
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreprocess(t *testing.T) {
|
||||
type preprocessCase struct {
|
||||
TestImage image.Image
|
||||
ExpectedLen int
|
||||
}
|
||||
|
||||
cases := []preprocessCase{
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
|
||||
ExpectedLen: 16 * 16 * 3 * 1,
|
||||
},
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
|
||||
ExpectedLen: 1024 * 1024 * 3 * 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
var buf bytes.Buffer
|
||||
err := png.Encode(&buf, c.TestImage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
imgData, _, err := Preprocess(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("error processing: %q", err)
|
||||
}
|
||||
|
||||
switch len(imgData) {
|
||||
case 0:
|
||||
t.Errorf("no image data returned")
|
||||
case c.ExpectedLen:
|
||||
// ok
|
||||
default:
|
||||
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreprocessImages(t *testing.T) {
|
||||
for _, testFile := range []string{"flight.png", "sportsball.png"} {
|
||||
f, err := os.Open(testFile)
|
||||
if err != nil {
|
||||
t.Skipf("skipping test, no test image found at %s", testFile)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
imgData, _, err := Preprocess(f)
|
||||
if err != nil {
|
||||
t.Fatalf("error processing: %q", err)
|
||||
}
|
||||
|
||||
byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
|
||||
for i, f := range imgData {
|
||||
binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
|
||||
}
|
||||
|
||||
outputPath := "processed_" + testFile + ".bin"
|
||||
err = os.WriteFile(outputPath, byteData, 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing processed image: %q", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if id := bpe.vocab.Encode(pair.value); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||
)
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
func replaceWhitespaceBySeperator(s string) string {
|
||||
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
||||
}
|
||||
|
||||
type SentencePieceModel struct {
|
||||
maxTokenLen int
|
||||
pre *regexp2.Regexp
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
|
||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
@@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||
|
||||
return SentencePieceModel{
|
||||
maxTokenLen: maxTokenLen,
|
||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
@@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
||||
if !yield(m.String()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
id := spm.vocab.Encode(special)
|
||||
for i := 0; i < len(fragments); i++ {
|
||||
frag := fragments[i]
|
||||
@@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
slog.Debug("fragments", "frags", fragments)
|
||||
|
||||
var ids []int32
|
||||
for _, frag := range fragments {
|
||||
@@ -100,105 +81,96 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
continue
|
||||
}
|
||||
|
||||
for split := range spm.split(frag.value) {
|
||||
split = replaceWhitespaceBySeperator(split)
|
||||
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.Write([]byte(split))
|
||||
if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
if id := spm.vocab.Encode(text); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
q := &queue{}
|
||||
heap.Init(q)
|
||||
|
||||
runes := []rune(text)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
pq := queue.NewWith(func(a, b any) int {
|
||||
priA := a.(*candidate)
|
||||
priB := b.(*candidate)
|
||||
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
})
|
||||
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
p: r - 1,
|
||||
n: r + 1,
|
||||
runes: []rune{runes[r]},
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("tokenizer", "merges", merges)
|
||||
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
}
|
||||
}
|
||||
pairwise := func(a, b int) *candidate {
|
||||
if a < 0 || b >= len(runes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||
return &candidate{
|
||||
a: a,
|
||||
b: b,
|
||||
score: spm.vocab.Scores[id],
|
||||
size: len(left) + len(right),
|
||||
}
|
||||
}
|
||||
|
||||
pqv := pq.Values()
|
||||
for _, v := range pqv {
|
||||
e := v.(*candidate)
|
||||
slog.Debug("candidate", "candidate", e)
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range len(runes) - 1 {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for q.Len() > 0 {
|
||||
pair := heap.Pop(q).(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
|
||||
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||
continue
|
||||
}
|
||||
|
||||
for !pq.Empty() {
|
||||
v, _ := pq.Dequeue()
|
||||
pair := v.(*candidate)
|
||||
left, right := merges[pair.a], merges[pair.b]
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
slog.Debug("pair", "left", left, "right", right)
|
||||
if len(left.runes) == 0 || len(right.runes) == 0 {
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
heap.Push(q, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, merge := range merges {
|
||||
if token := string(merge.runes); token != "" {
|
||||
id := spm.vocab.Encode(token)
|
||||
|
||||
if id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||
merges[pair.b].runes = nil
|
||||
merges[pair.a].n = right.n
|
||||
if right.n < len(merges) {
|
||||
merges[right.n].p = pair.a
|
||||
}
|
||||
|
||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
}
|
||||
|
||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||
pq.Enqueue(pair)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("merges", "merges", merges)
|
||||
|
||||
for _, merge := range merges {
|
||||
if len(merge.runes) > 0 {
|
||||
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
// Fallback to byte tokenization
|
||||
var result []int32
|
||||
for _, b := range []byte(token) {
|
||||
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||
unknownID := spm.vocab.Encode(byteToken)
|
||||
if unknownID >= 0 {
|
||||
result = append(result, unknownID)
|
||||
} else {
|
||||
slog.Debug("missing token", "token", string(merge.runes))
|
||||
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||
}
|
||||
}
|
||||
|
||||
ids = append(ids, result...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -229,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
type candidate struct {
|
||||
a, b int
|
||||
score float32
|
||||
size int
|
||||
}
|
||||
|
||||
type queue []*candidate
|
||||
|
||||
func (q queue) Len() int { return len(q) }
|
||||
|
||||
func (q queue) Less(i, j int) bool {
|
||||
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||
}
|
||||
|
||||
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||
|
||||
func (q *queue) Push(x interface{}) {
|
||||
item := x.(*candidate)
|
||||
*q = append(*q, item)
|
||||
}
|
||||
|
||||
func (q *queue) Pop() interface{} {
|
||||
old := *q
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*q = old[0 : n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
@@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
|
||||
// For tokenizers that use byte tokens like "<0xEA>"
|
||||
// convert them to the partial unicode character
|
||||
// so they are buffered correctly by the runner instead
|
||||
// of being sent back to the api as "<0xEA>"
|
||||
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||
}
|
||||
|
||||
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
if _, err := sb.WriteString(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
|
||||
|
||||
var v Vocabulary
|
||||
|
||||
for _, piece := range spm.GetPieces() {
|
||||
@@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePieceModel(preTokenizer, &v)
|
||||
return NewSentencePieceModel(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
@@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
"<0xEA>",
|
||||
"<0x41>",
|
||||
"<0xC3>",
|
||||
"<0xA3>",
|
||||
},
|
||||
Types: []uint32{
|
||||
TOKEN_TYPE_NORMAL,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
TOKEN_TYPE_BYTE,
|
||||
},
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePieceModel(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int32
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single byte token",
|
||||
ids: []int32{1},
|
||||
expected: "\xea",
|
||||
},
|
||||
{
|
||||
name: "ASCII byte token",
|
||||
ids: []int32{2},
|
||||
expected: "A",
|
||||
},
|
||||
{
|
||||
name: "multiple byte tokens forming UTF-8 character",
|
||||
ids: []int32{3, 4},
|
||||
expected: "ã",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := spm.Decode(tt.ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("got %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ import (
|
||||
var finishReasonToolCalls = "tool_calls"
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param interface{} `json:"param"`
|
||||
Code *string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param any `json:"param"`
|
||||
Code *string `json:"code"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
@@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
options := make(map[string]interface{})
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
|
||||
@@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]interface{}{
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
@@ -281,27 +281,31 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -11,10 +11,13 @@ import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
|
||||
@@ -144,12 +147,25 @@ func fileDigestMap(path string) (map[string]string, error) {
|
||||
files = []string{path}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
for _, f := range files {
|
||||
digest, err := digestForFile(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fl[f] = digest
|
||||
g.Go(func() error {
|
||||
digest, err := digestForFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
fl[f] = digest
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fl, nil
|
||||
@@ -211,16 +227,10 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
|
||||
var files []string
|
||||
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
// covers adapters.safetensors
|
||||
files = append(files, st...)
|
||||
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
// covers adapter_model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||
|
||||
@@ -116,19 +116,9 @@ func (i *Instance) Readline() (string, error) {
|
||||
|
||||
switch r {
|
||||
case KeyUp:
|
||||
if i.History.Pos > 0 {
|
||||
if i.History.Pos == i.History.Size() {
|
||||
currentLineBuf = []rune(buf.String())
|
||||
}
|
||||
buf.Replace([]rune(i.History.Prev()))
|
||||
}
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
case KeyDown:
|
||||
if i.History.Pos < i.History.Size() {
|
||||
buf.Replace([]rune(i.History.Next()))
|
||||
if i.History.Pos == i.History.Size() {
|
||||
buf.Replace(currentLineBuf)
|
||||
}
|
||||
}
|
||||
i.historyNext(buf, ¤tLineBuf)
|
||||
case KeyLeft:
|
||||
buf.MoveLeft()
|
||||
case KeyRight:
|
||||
@@ -185,6 +175,10 @@ func (i *Instance) Readline() (string, error) {
|
||||
esc = true
|
||||
case CharInterrupt:
|
||||
return "", ErrInterrupt
|
||||
case CharPrev:
|
||||
i.historyPrev(buf, ¤tLineBuf)
|
||||
case CharNext:
|
||||
i.historyNext(buf, ¤tLineBuf)
|
||||
case CharLineStart:
|
||||
buf.MoveToStart()
|
||||
case CharLineEnd:
|
||||
@@ -246,6 +240,24 @@ func (i *Instance) HistoryDisable() {
|
||||
i.History.Enabled = false
|
||||
}
|
||||
|
||||
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
|
||||
if i.History.Pos > 0 {
|
||||
if i.History.Pos == i.History.Size() {
|
||||
*currentLineBuf = []rune(buf.String())
|
||||
}
|
||||
buf.Replace([]rune(i.History.Prev()))
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
|
||||
if i.History.Pos < i.History.Size() {
|
||||
buf.Replace([]rune(i.History.Next()))
|
||||
if i.History.Pos == i.History.Size() {
|
||||
buf.Replace(*currentLineBuf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewTerminal() (*Terminal, error) {
|
||||
fd := os.Stdin.Fd()
|
||||
termios, err := SetRawMode(fd)
|
||||
|
||||
@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
||||
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
inputLen := len(slot.Inputs)
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
||||
}
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
||||
var shiftFailed bool
|
||||
|
||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
||||
if c.lc.KvCacheCanShift() {
|
||||
// For models that support shifting, attempt to shift the KV cache
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
} else {
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
|
||||
}
|
||||
} else {
|
||||
// For models that don't support shifting
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
}
|
||||
|
||||
if shiftFailed {
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Clear the entire KV cache
|
||||
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||
// Reset the slot inputs since we've cleared the cache
|
||||
slot.Inputs = []input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
|
||||
// Standard shift succeeded - update input array
|
||||
for i := numKeep + discard; i < inputLen; i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
||||
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
)
|
||||
|
||||
@@ -82,7 +83,7 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
doneReason string
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
|
||||
embedding bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// generating image embeddings for each image
|
||||
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
|
||||
var inputs []input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
@@ -229,7 +230,7 @@ type Server struct {
|
||||
image *ImageContext
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
status llm.ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
@@ -300,7 +301,7 @@ func flushPending(seq *Sequence) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
flushPending(seq)
|
||||
@@ -379,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -388,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Continue processing as normal
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
@@ -473,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
}
|
||||
|
||||
seq.embedding <- embed
|
||||
s.removeSequence(i, "")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -490,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -521,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
}
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -534,82 +543,25 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, "connection")
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
var req llm.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
@@ -620,26 +572,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var samplingParams llama.SamplingParams
|
||||
samplingParams.TopK = req.TopK
|
||||
samplingParams.TopP = req.TopP
|
||||
samplingParams.MinP = req.MinP
|
||||
samplingParams.TypicalP = req.TypicalP
|
||||
samplingParams.Temp = req.Temperature
|
||||
samplingParams.RepeatLastN = req.RepeatLastN
|
||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||
samplingParams.Mirostat = req.Mirostat
|
||||
samplingParams.MirostatTau = req.MirostatTau
|
||||
samplingParams.MirostatEta = req.MirostatEta
|
||||
samplingParams.Seed = uint32(req.Seed)
|
||||
samplingParams.Grammar = req.Grammar
|
||||
// Extract options from the CompletionRequest
|
||||
samplingParams := llama.SamplingParams{
|
||||
TopK: req.Options.TopK,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TypicalP: req.Options.TypicalP,
|
||||
Temp: req.Options.Temperature,
|
||||
RepeatLastN: req.Options.RepeatLastN,
|
||||
PenaltyRepeat: req.Options.RepeatPenalty,
|
||||
PenaltyFreq: req.Options.FrequencyPenalty,
|
||||
PenaltyPresent: req.Options.PresencePenalty,
|
||||
Mirostat: req.Options.Mirostat,
|
||||
MirostatTau: req.Options.MirostatTau,
|
||||
MirostatEta: req.Options.MirostatEta,
|
||||
Seed: uint32(req.Options.Seed),
|
||||
Grammar: req.Grammar,
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: req.NumKeep,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
})
|
||||
@@ -653,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -662,9 +616,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -680,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -691,7 +647,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -701,16 +657,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numDecoded,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
EvalCount: seq.numDecoded,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -721,17 +674,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
var req EmbeddingRequest
|
||||
var req llm.EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
@@ -752,7 +696,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embeddings request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -761,9 +705,10 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -776,47 +721,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
embedding := <-seq.embedding
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -879,7 +801,7 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -937,7 +859,7 @@ func Execute(args []string) error {
|
||||
parallel: *parallel,
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||
status: ServerStatusLoadingModel,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
|
||||
@@ -31,8 +31,10 @@ type InputCache struct {
|
||||
cache kvcache.Cache
|
||||
}
|
||||
|
||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
|
||||
if kvSize/int32(numSlots) < 1 {
|
||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||
numCtx := kvSize / int32(numSlots)
|
||||
|
||||
if numCtx < 1 {
|
||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||
}
|
||||
|
||||
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
||||
|
||||
cache := model.Config().Cache
|
||||
if cache != nil {
|
||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
|
||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
|
||||
}
|
||||
|
||||
return &InputCache{
|
||||
numCtx: kvSize / int32(numSlots),
|
||||
numCtx: numCtx,
|
||||
enabled: cache != nil,
|
||||
slots: slots,
|
||||
multiUserCache: multiUserCache,
|
||||
@@ -89,7 +91,7 @@ type InputCacheSlot struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
@@ -107,10 +109,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
@@ -120,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
||||
}
|
||||
|
||||
if c.cache != nil {
|
||||
if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
||||
if err != nil {
|
||||
// Some models don't support partial erasure
|
||||
@@ -227,6 +229,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
||||
return count
|
||||
}
|
||||
|
||||
// TODO(jessegross): If we need to reprocess the inputs we should ensure that
|
||||
// we don't split up a SameBatch
|
||||
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
@@ -241,6 +245,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
return discard
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
@@ -260,11 +272,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if c.cache != nil {
|
||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
||||
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
|
||||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, -1)
|
||||
slot.Inputs = []input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -297,3 +300,220 @@ func TestShiftDiscard(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
wantErr bool
|
||||
expectedSlotId int
|
||||
expectedPrompt int // expected length of remaining prompt
|
||||
}{
|
||||
{
|
||||
name: "Basic cache hit - single user",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Basic cache hit - multi user",
|
||||
cache: InputCache{
|
||||
multiUserCache: true,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
},
|
||||
{
|
||||
name: "Exact match - leave one input",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||
},
|
||||
{
|
||||
name: "No available slots",
|
||||
cache: InputCache{
|
||||
multiUserCache: false,
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: true,
|
||||
expectedSlotId: -1,
|
||||
expectedPrompt: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
||||
|
||||
// Check error state
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return // Skip further checks if we expected an error
|
||||
}
|
||||
|
||||
// Verify slot ID
|
||||
if slot.Id != tt.expectedSlotId {
|
||||
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
|
||||
}
|
||||
|
||||
// Verify slot is now marked in use
|
||||
if !slot.InUse {
|
||||
t.Errorf("LoadCacheSlot() slot not marked InUse")
|
||||
}
|
||||
|
||||
// Verify remaining prompt length
|
||||
if len(remainingPrompt) != tt.expectedPrompt {
|
||||
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
|
||||
len(remainingPrompt), tt.expectedPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementation of the Cache interface
|
||||
type mockCache struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
// Implement only the methods needed for the test
|
||||
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if m.shouldFail {
|
||||
return fmt.Errorf("mock cache removal error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub implementations for other interface methods
|
||||
func (m *mockCache) SetLayer(layer int) {}
|
||||
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
||||
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||
func (m *mockCache) Close() {}
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil }
|
||||
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }
|
||||
|
||||
func TestShiftCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
wantInputsLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
wantInputsLen: 6, // After discarding 4 tokens
|
||||
},
|
||||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
wantInputsLen: 0, // Original inputs should be cleared
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &mockCache{shouldFail: tt.cacheErr}
|
||||
c := InputCache{
|
||||
numCtx: tt.numCtx,
|
||||
cache: mock,
|
||||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
err := c.ShiftCacheSlot(slot, tt.numKeep)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.As(err, &tt.wantErr) {
|
||||
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(slot.Inputs) != tt.wantInputsLen {
|
||||
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -33,10 +34,14 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
|
||||
type contextList struct {
|
||||
list []ml.Context
|
||||
}
|
||||
|
||||
type Sequence struct {
|
||||
// ctx for allocating tensors that last the lifetime of the sequence, such as
|
||||
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
|
||||
// multimodal embeddings
|
||||
ctx ml.Context
|
||||
ctxs *contextList
|
||||
|
||||
// batch index
|
||||
iBatch int
|
||||
@@ -77,7 +82,7 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
doneReason string
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
@@ -94,13 +99,12 @@ type NewSequenceParams struct {
|
||||
embedding bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
ctx := s.model.Backend().NewContext()
|
||||
|
||||
inputs, err := s.inputs(ctx, prompt, images)
|
||||
inputs, ctxs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
} else if len(inputs) == 0 {
|
||||
@@ -116,8 +120,36 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
|
||||
sameBatch := 0
|
||||
for i, inp := range inputs {
|
||||
if sameBatch > 0 {
|
||||
sameBatch--
|
||||
|
||||
if promptStart == int32(i) {
|
||||
promptStart++
|
||||
}
|
||||
} else if promptStart == int32(i) {
|
||||
break
|
||||
}
|
||||
|
||||
if inp.SameBatch != 0 {
|
||||
if int32(i) < params.numKeep {
|
||||
return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
|
||||
}
|
||||
|
||||
sameBatch = inp.SameBatch
|
||||
}
|
||||
}
|
||||
|
||||
if promptStart >= int32(len(inputs)) {
|
||||
return nil, errors.New("entire prompt removed by truncation")
|
||||
}
|
||||
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
newInputs = append(newInputs, inputs[promptStart:]...)
|
||||
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||
inputs = newInputs
|
||||
@@ -126,7 +158,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
ctx: ctx,
|
||||
ctxs: ctxs,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
@@ -145,7 +177,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
|
||||
var inputs []input.Input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
@@ -160,12 +192,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
parts = []string{prompt}
|
||||
}
|
||||
|
||||
var contexts contextList
|
||||
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
|
||||
for _, ctx := range ctxs {
|
||||
ctx.Close()
|
||||
}
|
||||
}, contexts.list)
|
||||
|
||||
postTokenize := false
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
@@ -185,12 +224,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
}
|
||||
|
||||
if imageIndex < 0 {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
return nil, nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
contexts.list = append(contexts.list, ctx)
|
||||
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.multimodalHash.Reset()
|
||||
@@ -204,13 +245,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
||||
|
||||
if visionModel && postTokenize {
|
||||
var err error
|
||||
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
|
||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
return inputs, &contexts, nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
@@ -222,7 +263,7 @@ type Server struct {
|
||||
model model.Model
|
||||
|
||||
// status for external health reporting - loading, ready to serve, etc.
|
||||
status ServerStatus
|
||||
status llm.ServerStatus
|
||||
|
||||
// current progress on loading the model
|
||||
progress float32
|
||||
@@ -251,6 +292,9 @@ type Server struct {
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
@@ -297,7 +341,7 @@ func flushPending(seq *Sequence) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
flushPending(seq)
|
||||
@@ -305,7 +349,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
seq.ctx.Close()
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
@@ -333,16 +376,22 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var options input.Options
|
||||
var batchInputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, "limit")
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -351,33 +400,61 @@ func (s *Server) processBatch() error {
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
}
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
batchSize := s.batchSize
|
||||
|
||||
for i, inp := range seq.inputs {
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have existing inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
|
||||
|
||||
if j >= s.batchSize {
|
||||
// Stop if the required batch would put us over the total batch size (including tokens
|
||||
// added by other sequences). If we haven't been able to add anything yet then pick up
|
||||
// here again for the next batch to avoid starvation, though we can opportunistically
|
||||
// check if other sequences can still squeeze something in.
|
||||
if len(batchInputs)+minBatch > batchSize {
|
||||
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
|
||||
resumeSeq = seqIdx
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
// If the sum of our working set (already processed tokens, tokens we added to this
|
||||
// batch, required following tokens) exceeds the context size, then trigger a shift
|
||||
// now so we don't have to do one later when we can't break the batch.
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
||||
batchInputs = append(batchInputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
|
||||
}
|
||||
|
||||
seq.iBatch = len(options.Outputs)
|
||||
if j+1 == len(seq.inputs) {
|
||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
@@ -385,14 +462,20 @@ func (s *Server) processBatch() error {
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if len(options.Inputs) == 0 {
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
s.nextSeq = seqIdx + 1
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
modelOutput, err := model.Forward(ctx, s.model, options)
|
||||
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
@@ -427,12 +510,12 @@ func (s *Server) processBatch() error {
|
||||
if seq.embeddingOnly {
|
||||
// TODO(jessegross): Embedding support
|
||||
slog.Warn("generation of embedding outputs not yet supported")
|
||||
s.removeSequence(i, "")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(options.Outputs)
|
||||
vocabSize := len(logits) / len(batch.Outputs)
|
||||
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
if err != nil {
|
||||
@@ -445,7 +528,7 @@ func (s *Server) processBatch() error {
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -481,7 +564,7 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
s.removeSequence(i, "stop")
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -494,82 +577,25 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
|
||||
if !flushPending(seq) {
|
||||
s.removeSequence(i, "connection")
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO (jmorganca): use structs from the api package to avoid duplication
|
||||
// this way the api acts as a proxy instead of using a different api for the
|
||||
// runner
|
||||
type Options struct {
|
||||
api.Runner
|
||||
|
||||
NumKeep int `json:"n_keep"`
|
||||
Seed int `json:"seed"`
|
||||
NumPredict int `json:"n_predict"`
|
||||
TopK int `json:"top_k"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TypicalP float32 `json:"typical_p"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
RepeatPenalty float32 `json:"repeat_penalty"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
Mirostat int `json:"mirostat"`
|
||||
MirostatTau float32 `json:"mirostat_tau"`
|
||||
MirostatEta float32 `json:"mirostat_eta"`
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
AspectRatioID int `json:"aspect_ratio_id"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Images []ImageData `json:"image_data"`
|
||||
Grammar string `json:"grammar"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
type Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||
PredictedN int `json:"predicted_n,omitempty"`
|
||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||
PromptN int `json:"prompt_n,omitempty"`
|
||||
PromptMS float64 `json:"prompt_ms,omitempty"`
|
||||
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
var req llm.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
}
|
||||
|
||||
// Set the headers to indicate streaming
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
@@ -591,18 +617,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sampler := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
})
|
||||
@@ -616,7 +642,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -625,9 +651,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -641,6 +668,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -652,7 +680,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -662,16 +690,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// Send the final response
|
||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
||||
Stop: true,
|
||||
StoppedLimit: seq.doneReason == "limit",
|
||||
Timings: Timings{
|
||||
PromptN: seq.numPromptInputs,
|
||||
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
|
||||
PredictedN: seq.numPredicted,
|
||||
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
|
||||
},
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
EvalCount: seq.numPredicted,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -682,102 +707,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
CachePrompt bool `json:"cache_prompt"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
var req EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
slog.Debug("embedding request", "content", req.Content)
|
||||
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embeddings request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
embedding := <-seq.embedding
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Progress float32 `json:"progress"`
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const (
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "ok"
|
||||
case ServerStatusLoadingModel:
|
||||
return "loading model"
|
||||
default:
|
||||
return "server error"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&HealthResponse{
|
||||
Status: s.status.ToString(),
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
Status: s.status,
|
||||
Progress: s.progress,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
@@ -795,7 +728,53 @@ func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
func (s *Server) reserveWorstCaseGraph() error {
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
var batch input.Batch
|
||||
|
||||
inputs := make([]int32, s.batchSize)
|
||||
batch.Positions = make([]int32, len(inputs))
|
||||
batch.Sequences = make([]int, len(inputs))
|
||||
for i := range inputs {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Outputs = make([]int32, s.parallel)
|
||||
for i := range batch.Outputs {
|
||||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
var err error
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cache := s.model.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := s.model.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ctx.Forward(t).Reserve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
ctx context.Context,
|
||||
mpath string,
|
||||
params ml.BackendParams,
|
||||
lpath multiLPath,
|
||||
@@ -805,7 +784,7 @@ func (s *Server) loadModel(
|
||||
multiUserCache bool,
|
||||
) {
|
||||
var err error
|
||||
s.model, err = model.New(mpath, params)
|
||||
s.model, err = model.New(ctx, mpath, params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -817,7 +796,7 @@ func (s *Server) loadModel(
|
||||
panic("loras are not yet implemented")
|
||||
}
|
||||
|
||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
|
||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -831,7 +810,12 @@ func (s *Server) loadModel(
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
s.status = ServerStatusReady
|
||||
err = s.reserveWorstCaseGraph()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -883,7 +867,7 @@ func Execute(args []string) error {
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
status: ServerStatusLoadingModel,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
@@ -901,6 +885,9 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
params := ml.BackendParams{
|
||||
Progress: func(progress float32) {
|
||||
server.progress = progress
|
||||
},
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
@@ -909,13 +896,13 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
@@ -927,9 +914,13 @@ func Execute(args []string) error {
|
||||
defer listener.Close()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/embedding", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
// TODO: support embeddings
|
||||
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /completion", server.completion)
|
||||
mux.HandleFunc("GET /health", server.health)
|
||||
|
||||
httpServer := http.Server{
|
||||
Handler: mux,
|
||||
|
||||
@@ -26,6 +26,10 @@ type Sampler struct {
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
tokens[i].id = int32(i)
|
||||
@@ -84,25 +88,16 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
return greedy(tokens), nil
|
||||
}
|
||||
|
||||
if s.topK > 0 {
|
||||
tokens = topK(tokens, s.topK)
|
||||
} else {
|
||||
sortLogits(tokens)
|
||||
}
|
||||
// topK also sorts the tokens in descending order of logits
|
||||
tokens = topK(tokens, s.topK)
|
||||
|
||||
// token logit values are updated to probabilities
|
||||
tokens = temperature(tokens, s.temperature)
|
||||
// scale and normalize the tokens in place
|
||||
temperature(tokens, s.temperature)
|
||||
softmax(tokens)
|
||||
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
||||
// TODO: this should fall back to greedy sampling
|
||||
// or topP, topK values etc should be such that
|
||||
// there are always tokens to sample from
|
||||
if len(tokens) == 0 {
|
||||
return token{}, errors.New("no tokens to sample from")
|
||||
}
|
||||
|
||||
var r float32
|
||||
if s.rng != nil {
|
||||
r = s.rng.Float32()
|
||||
@@ -125,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
return 1
|
||||
})
|
||||
|
||||
if math.IsNaN(float64(sum)) {
|
||||
return token{}, errors.New("sample: logits sum to NaN, check model output")
|
||||
}
|
||||
return tokens[idx], nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
)
|
||||
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
// Test very high p
|
||||
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
|
||||
// Use extremely small topP to filter out all tokens
|
||||
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
// Should get the token with the highest logit
|
||||
want = int32(0)
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
|
||||
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
|
||||
got, err = sampler.Sample(logits)
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got %d", got)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
|
||||
@@ -1,12 +1,41 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"math"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// temperature applies scaling and softmax to the logits
|
||||
func temperature(ts []token, temp float32) []token {
|
||||
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
|
||||
type tokenHeap []token
|
||||
|
||||
func (h tokenHeap) Len() int { return len(h) }
|
||||
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
|
||||
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
func (h *tokenHeap) Push(x any) {
|
||||
*h = append(*h, x.(token))
|
||||
}
|
||||
|
||||
func (h *tokenHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// temperature applies scaling to the logits
|
||||
func temperature(ts []token, temp float32) {
|
||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
||||
temp = max(temp, 1e-7)
|
||||
for i := range ts {
|
||||
ts[i].value = ts[i].value / temp
|
||||
}
|
||||
}
|
||||
|
||||
// softmax applies normalization to the logits
|
||||
func softmax(ts []token) {
|
||||
// Find max logit for numerical stability
|
||||
maxLogit := float32(math.Inf(-1))
|
||||
for _, t := range ts {
|
||||
@@ -15,81 +44,59 @@ func temperature(ts []token, temp float32) []token {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply temperature and compute exp(x - max)
|
||||
temp = max(temp, 1e-7)
|
||||
// Compute exp(x - max)
|
||||
var sum float32
|
||||
for i, v := range ts {
|
||||
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
|
||||
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
|
||||
sum += ts[i].value
|
||||
}
|
||||
|
||||
// Normalize
|
||||
// exp(x - max) / sum(exp(x - max))
|
||||
for i := range ts {
|
||||
ts[i].value /= sum
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
|
||||
//
|
||||
// The heap is represented as an array where for any node at index i:
|
||||
// - Left child is at index 2i + 1
|
||||
// - Right child is at index 2i + 2
|
||||
// - Parent is at index (i-1)/2
|
||||
//
|
||||
// The function compares a node with its children and:
|
||||
// 1. Finds the smallest value between the node and its children
|
||||
// 2. If the node is not the smallest, swaps it with its smallest child
|
||||
// 3. Continues this process down the affected path until the min-heap property is restored
|
||||
func siftDown(data []token, start, end int) {
|
||||
root := start
|
||||
for {
|
||||
child := 2*root + 1
|
||||
if child >= end {
|
||||
break
|
||||
}
|
||||
// Find smaller child (we want min heap)
|
||||
if child+1 < end && data[child+1].value < data[child].value {
|
||||
child++
|
||||
}
|
||||
// Exit if root is already smaller than children
|
||||
if data[root].value <= data[child].value {
|
||||
break
|
||||
}
|
||||
// Swap with smaller child and continue
|
||||
data[root], data[child] = data[child], data[root]
|
||||
root = child
|
||||
}
|
||||
}
|
||||
|
||||
// topK limits the number of tokens considered to the k highest logits
|
||||
func topK(ts []token, k int) []token {
|
||||
if k >= len(ts) {
|
||||
if k >= len(ts) || k <= 0 {
|
||||
slices.SortFunc(ts, func(a, b token) int {
|
||||
switch {
|
||||
case a.value < b.value:
|
||||
return 1
|
||||
case a.value > b.value:
|
||||
return -1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
return ts
|
||||
}
|
||||
// Heapify + siftDown - O(nlog(k))
|
||||
// Build min-heap of first k elements
|
||||
heap := ts[:k]
|
||||
for i := k/2 - 1; i >= 0; i-- {
|
||||
siftDown(heap, i, k)
|
||||
}
|
||||
|
||||
// Process remaining elements - if larger than heap root, replace root
|
||||
// Initialize min-heap with first k elements
|
||||
h := make(tokenHeap, k)
|
||||
copy(h, ts[:k])
|
||||
heap.Init(&h)
|
||||
|
||||
// Process remaining elements
|
||||
for i := k; i < len(ts); i++ {
|
||||
if ts[i].value > heap[0].value {
|
||||
heap[0] = ts[i]
|
||||
siftDown(heap, 0, k)
|
||||
if ts[i].value > h[0].value {
|
||||
heap.Pop(&h)
|
||||
heap.Push(&h, ts[i])
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(heap)
|
||||
// Convert heap to sorted slice in descending order
|
||||
result := make([]token, len(h))
|
||||
for i := k - 1; i >= 0; i-- {
|
||||
result[i] = heap.Pop(&h).(token)
|
||||
}
|
||||
|
||||
ts = heap
|
||||
return ts
|
||||
return result
|
||||
}
|
||||
|
||||
// topP limits tokens to those with cumulative probability p
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func topP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
@@ -100,96 +107,24 @@ func topP(ts []token, p float32) []token {
|
||||
for i, t := range ts {
|
||||
sum += t.value
|
||||
if sum > float32(p) {
|
||||
ts = ts[:i+1]
|
||||
return ts
|
||||
return ts[:i+1]
|
||||
}
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// minP limits tokens to those with cumulative probability p
|
||||
// minP filters tokens with probabilities >= p * max_prob
|
||||
// requires ts to be sorted in descending order of probabilities
|
||||
func minP(ts []token, p float32) []token {
|
||||
if p == 1.0 {
|
||||
return ts
|
||||
}
|
||||
maxProb := ts[0].value
|
||||
|
||||
maxProb := float32(math.Inf(-1))
|
||||
for _, token := range ts {
|
||||
if token.value > maxProb {
|
||||
maxProb = token.value
|
||||
threshold := maxProb * p
|
||||
|
||||
for i, t := range ts {
|
||||
if t.value < threshold {
|
||||
return ts[:i]
|
||||
}
|
||||
}
|
||||
|
||||
threshold := maxProb * float32(p)
|
||||
|
||||
// Filter tokens in-place
|
||||
validTokens := ts[:0]
|
||||
for i, token := range ts {
|
||||
if token.value >= threshold {
|
||||
validTokens = append(validTokens, ts[i])
|
||||
}
|
||||
}
|
||||
|
||||
ts = validTokens
|
||||
return ts
|
||||
}
|
||||
|
||||
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
|
||||
// sortLogits sorts implementation to sort tokens by logits using counting sort
|
||||
// counting sort is faster than built-in sort for this use case
|
||||
func sortLogits(tokens []token) {
|
||||
if len(tokens) <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Find max/min in a single pass
|
||||
minLogit, maxLogit := tokens[0].value, tokens[0].value
|
||||
for _, t := range tokens[1:] {
|
||||
if t.value < minLogit {
|
||||
minLogit = t.value
|
||||
} else if t.value > maxLogit {
|
||||
maxLogit = t.value
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate scaling to map to uint32 range
|
||||
logitRange := maxLogit - minLogit
|
||||
if logitRange < 1e-6 {
|
||||
return // All values effectively equal
|
||||
}
|
||||
|
||||
// Count frequencies directly from tokens
|
||||
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
|
||||
var counts [256]int // For first byte
|
||||
|
||||
// First pass: count frequencies
|
||||
for _, t := range tokens {
|
||||
// Map to [0, maxInt] range
|
||||
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
|
||||
counts[score>>16]++
|
||||
}
|
||||
|
||||
// Calculate offsets
|
||||
var offset int
|
||||
for i := range counts {
|
||||
count := counts[i]
|
||||
counts[i] = offset
|
||||
offset += count
|
||||
}
|
||||
|
||||
// Second pass: place elements in correct position
|
||||
output := make([]token, len(tokens))
|
||||
// Track current positions
|
||||
countsCopy := counts
|
||||
|
||||
for i, t := range tokens {
|
||||
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
|
||||
|
||||
pos := countsCopy[score>>16]
|
||||
countsCopy[score>>16]++
|
||||
output[len(tokens)-1-pos] = tokens[i]
|
||||
}
|
||||
|
||||
copy(tokens, output)
|
||||
}
|
||||
|
||||
@@ -6,122 +6,293 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Helper to convert float64 slice to logit slice
|
||||
func toTokens(values []float64) []token {
|
||||
// Helper to convert float32 slice to logit slice
|
||||
func toTokens(values []float32) []token {
|
||||
tokens := make([]token, len(values))
|
||||
for i, v := range values {
|
||||
tokens[i] = token{
|
||||
id: int32(i),
|
||||
value: float32(v),
|
||||
value: v,
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Helper to compare logit slices
|
||||
func compareLogits(t *testing.T, name string, want []float64, got []token) {
|
||||
func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
||||
t.Helper()
|
||||
if len(want) != len(got) {
|
||||
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
|
||||
return
|
||||
}
|
||||
for i := range want {
|
||||
if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
|
||||
if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
|
||||
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperatureAndSoftmax(t *testing.T) {
|
||||
input := []float64{1, 4, -2, 0}
|
||||
got := temperature(toTokens(input), 0.5)
|
||||
func TestTemperature(t *testing.T) {
|
||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens := toTokens(input)
|
||||
temperature(tokens, 0.5)
|
||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
||||
compareLogits(t, "temperature(0.5)", want, tokens)
|
||||
|
||||
// Check probabilities sum to 1
|
||||
var sum float32
|
||||
for _, token := range got {
|
||||
sum += token.value
|
||||
}
|
||||
if math.Abs(float64(sum)-1.0) > 1e-6 {
|
||||
t.Errorf("probabilities don't sum to 1: got %f", sum)
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 1.0)
|
||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
compareLogits(t, "temperature(1)", want, tokens)
|
||||
|
||||
input = []float32{1.0, 4.0, -2.0, 0.0}
|
||||
tokens = toTokens(input)
|
||||
temperature(tokens, 0.0)
|
||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
||||
compareLogits(t, "temperature(0)", want, tokens)
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []float32
|
||||
expected []float32
|
||||
}{
|
||||
{
|
||||
name: "correctness softmax",
|
||||
input: []float32{1, -2, 3, 0},
|
||||
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
|
||||
},
|
||||
{
|
||||
name: "normal distribution",
|
||||
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
|
||||
},
|
||||
{
|
||||
name: "single value",
|
||||
input: []float32{1.0},
|
||||
},
|
||||
{
|
||||
name: "identical values",
|
||||
input: []float32{0.9, 0.9, 0.9},
|
||||
},
|
||||
{
|
||||
name: "large values",
|
||||
input: []float32{1000.0, 2000.0, 3000.0},
|
||||
},
|
||||
{
|
||||
name: "small values",
|
||||
input: []float32{1e-6, 2e-6, 3e-6},
|
||||
},
|
||||
{
|
||||
name: "negative values",
|
||||
input: []float32{-1.0, -2.0, -3.0},
|
||||
},
|
||||
{
|
||||
name: "mixed values",
|
||||
input: []float32{-100.0, 0.0, 100.0},
|
||||
},
|
||||
}
|
||||
|
||||
got = temperature(toTokens(input), 1)
|
||||
// Check probabilities sum to 1
|
||||
sum = 0.0
|
||||
for _, token := range got {
|
||||
sum += token.value
|
||||
}
|
||||
if math.Abs(float64(sum)-1.0) > 1e-6 {
|
||||
t.Errorf("probabilities don't sum to 1: got %f", sum)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens := toTokens(tt.input)
|
||||
softmax(tokens)
|
||||
|
||||
if tt.expected != nil {
|
||||
compareLogits(t, tt.name, tt.expected, tokens)
|
||||
return
|
||||
}
|
||||
|
||||
// Check probabilities sum to 1
|
||||
var sum float32
|
||||
for _, token := range tokens {
|
||||
sum += token.value
|
||||
if token.value < 0 || token.value > 1 {
|
||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
||||
}
|
||||
}
|
||||
if math.Abs(float64(sum-1.0)) > 1e-6 {
|
||||
t.Errorf("probabilities don't sum to 1: got %f", sum)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
input := []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||
|
||||
// Test k=3
|
||||
got := topK(toTokens(input), 3)
|
||||
if len(got) != 3 {
|
||||
t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
tokens := toTokens(input)
|
||||
tokens = topK(tokens, 5)
|
||||
if len(tokens) != 5 {
|
||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
|
||||
}
|
||||
// Should keep highest 3 values: 4, 2, 1
|
||||
want := []float64{4, 2, 1}
|
||||
compareLogits(t, "topK(3)", want, got)
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
||||
compareLogits(t, "topK(3)", want, tokens)
|
||||
|
||||
// Test k > len
|
||||
got = topK(toTokens(input), 10)
|
||||
compareLogits(t, "topK(10)", input, got)
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, -1)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 0)
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
|
||||
}
|
||||
compareLogits(t, "topK(-1)", want, tokens)
|
||||
|
||||
input = []float32{-1e7, -2e7, -3e7, -4e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 1)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topK should keep at least one token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
input := []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||
input := []float32{-3, -2, -1, 0, 1, 2, 4}
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax to get probabilities
|
||||
tokens = temperature(tokens, 1)
|
||||
sortLogits(tokens)
|
||||
softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
// Then apply topP
|
||||
got := topP(tokens, 0.95)
|
||||
// Test with very high p value
|
||||
got := topP(tokens, 1.0)
|
||||
|
||||
// Should keep all tokens since p is 1
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = topP(tokens, 0.95)
|
||||
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e7}
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
got = topP(tokens, 0.0)
|
||||
if len(got) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
got = topP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(got) != 1 {
|
||||
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
tokens = toTokens(input)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
got = topP(tokens, 1e-10)
|
||||
if len(got) == 0 {
|
||||
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP(t *testing.T) {
|
||||
input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
|
||||
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
|
||||
tokens := toTokens(input)
|
||||
|
||||
// First apply temperature and softmax
|
||||
tokens = temperature(tokens, 1)
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
|
||||
// Then apply minP
|
||||
got := minP(tokens, 0.2)
|
||||
tokens = minP(tokens, 1.0)
|
||||
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortLogits(t *testing.T) {
|
||||
input := []float64{3, 1, 4, 2, -1, 0, -2}
|
||||
tokens := toTokens(input)
|
||||
// Test with zero p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.0)
|
||||
|
||||
sortLogits(tokens)
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
for i := 1; i < len(tokens); i++ {
|
||||
if tokens[i].value > tokens[i-1].value {
|
||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
||||
i, tokens[i].value, tokens[i-1].value)
|
||||
// Test with single token
|
||||
tokens = toTokens(input[:1])
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.1)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
input = []float32{1e-10, 1e-10, 1e-10}
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
got := minP(tokens, 1.0)
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
got = minP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(got) != len(tokens) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
want := []float64{4, 3, 2, 1, 0, -1, -2}
|
||||
compareLogits(t, "sortLogits", want, tokens)
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
@@ -144,11 +315,19 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Softmax", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
softmax(tokensCopy)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("TopK", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topK(tokensCopy, 10)
|
||||
tokens = topK(tokensCopy, 10)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -156,7 +335,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
topP(tokensCopy, 0.9)
|
||||
tokens = topP(tokensCopy, 0.9)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -164,7 +343,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
minP(tokensCopy, 0.2)
|
||||
tokens = minP(tokensCopy, 0.2)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -172,7 +351,7 @@ func BenchmarkTransforms(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
copy(tokensCopy, tokens)
|
||||
sortLogits(tokensCopy)
|
||||
tokens = topK(tokensCopy, 200000)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ usage() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
export VERSION=${VERSION:-$(git describe --tags --dirty)}
|
||||
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
|
||||
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
|
||||
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'
|
||||
|
||||
|
||||
@@ -29,8 +29,9 @@ import (
|
||||
const maxRetries = 6
|
||||
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) > 10 {
|
||||
return errors.New("maximum redirects exceeded (10) for directURL")
|
||||
return errMaxRedirectsExceeded
|
||||
}
|
||||
|
||||
// if the hostname is the same, allow the redirect
|
||||
|
||||
108
server/images.go
108
server/images.go
@@ -35,14 +35,9 @@ var (
|
||||
errCapabilityCompletion = errors.New("completion")
|
||||
errCapabilityTools = errors.New("tools")
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
)
|
||||
|
||||
type Capability string
|
||||
|
||||
const (
|
||||
CapabilityCompletion = Capability("completion")
|
||||
CapabilityTools = Capability("tools")
|
||||
CapabilityInsert = Capability("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
type registryOptions struct {
|
||||
@@ -65,52 +60,83 @@ type Model struct {
|
||||
System string
|
||||
License []string
|
||||
Digest string
|
||||
Options map[string]interface{}
|
||||
Options map[string]any
|
||||
Messages []api.Message
|
||||
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
r, err := os.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer r.Close()
|
||||
|
||||
f, _, err := ggml.Decode(r, 0)
|
||||
if err == nil {
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
|
||||
if m.Template == nil {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for tools capability
|
||||
if slices.Contains(m.Template.Vars(), "tools") {
|
||||
capabilities = append(capabilities, model.CapabilityTools)
|
||||
}
|
||||
|
||||
// Check for insert capability
|
||||
if slices.Contains(m.Template.Vars(), "suffix") {
|
||||
capabilities = append(capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||
// any missing or unknown capabilities
|
||||
func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
available := m.Capabilities()
|
||||
var errs []error
|
||||
for _, cap := range caps {
|
||||
switch cap {
|
||||
case CapabilityCompletion:
|
||||
r, err := os.Open(m.ModelPath)
|
||||
if err != nil {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
continue
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
|
||||
f, _, err := ggml.Decode(r, 0)
|
||||
if err != nil {
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
continue
|
||||
}
|
||||
// Map capabilities to their corresponding error
|
||||
capToErr := map[model.Capability]error{
|
||||
model.CapabilityCompletion: errCapabilityCompletion,
|
||||
model.CapabilityTools: errCapabilityTools,
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
}
|
||||
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
errs = append(errs, errCapabilityCompletion)
|
||||
}
|
||||
case CapabilityTools:
|
||||
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||
errs = append(errs, errCapabilityTools)
|
||||
}
|
||||
case CapabilityInsert:
|
||||
vars := m.Template.Vars()
|
||||
if !slices.Contains(vars, "suffix") {
|
||||
errs = append(errs, errCapabilityInsert)
|
||||
}
|
||||
default:
|
||||
for _, cap := range want {
|
||||
err, ok := capToErr[cap]
|
||||
if !ok {
|
||||
slog.Error("unknown capability", "capability", cap)
|
||||
return fmt.Errorf("unknown capability: %s", cap)
|
||||
}
|
||||
|
||||
if !slices.Contains(available, cap) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
}
|
||||
|
||||
@@ -479,7 +505,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
@@ -543,7 +569,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
360
server/images_test.go
Normal file
360
server/images_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Constants for GGUF magic bytes and version
|
||||
var (
|
||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
||||
ggufVer = uint32(3) // Version 3
|
||||
)
|
||||
|
||||
// Helper function to create mock GGUF data
|
||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Write GGUF header
|
||||
buf.Write(ggufMagic)
|
||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
||||
|
||||
// Write tensor count (0 for our test)
|
||||
var numTensors uint64 = 0
|
||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
||||
|
||||
// Calculate number of metadata entries
|
||||
numMetaEntries := uint64(1) // architecture entry
|
||||
if vision {
|
||||
numMetaEntries++
|
||||
}
|
||||
// Add embedding entry if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
numMetaEntries++
|
||||
}
|
||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
||||
|
||||
// Write architecture metadata
|
||||
archKey := "general.architecture"
|
||||
keyLen := uint64(len(archKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(archKey)
|
||||
|
||||
// String type (8)
|
||||
var strType uint32 = 8
|
||||
binary.Write(&buf, binary.LittleEndian, strType)
|
||||
|
||||
// String length
|
||||
strLen := uint64(len(architecture))
|
||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
||||
buf.WriteString(architecture)
|
||||
|
||||
if vision {
|
||||
visionKey := architecture + ".vision.block_count"
|
||||
keyLen = uint64(len(visionKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(visionKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var countVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
||||
}
|
||||
// Write embedding metadata if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
poolKey := architecture + ".pooling_type"
|
||||
keyLen = uint64(len(poolKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(poolKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var poolingVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "model_capabilities_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create different types of mock model files
|
||||
completionModelPath := filepath.Join(tempDir, "model.bin")
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
// Create a simple model file for tests that don't depend on GGUF content
|
||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||
|
||||
err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create completion model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create completion model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
testModels := []struct {
|
||||
name string
|
||||
model Model
|
||||
expectedCaps []model.Capability
|
||||
}{
|
||||
{
|
||||
name: "model with completion capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||
},
|
||||
|
||||
{
|
||||
name: "model with completion, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools and insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with vision, tools, and insert capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
}
|
||||
|
||||
// compare two slices of model.Capability regardless of order
|
||||
compareCapabilities := func(a, b []model.Capability) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aCount := make(map[model.Capability]int)
|
||||
for _, cap := range a {
|
||||
aCount[cap]++
|
||||
}
|
||||
|
||||
bCount := make(map[model.Capability]int)
|
||||
for _, cap := range b {
|
||||
bCount[cap]++
|
||||
}
|
||||
|
||||
for cap, count := range aCount {
|
||||
if bCount[cap] != count {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
for _, tt := range testModels {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test Capabilities method
|
||||
caps := tt.model.Capabilities()
|
||||
if !compareCapabilities(caps, tt.expectedCaps) {
|
||||
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
|
||||
err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create simple model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vision model file: %v", err)
|
||||
}
|
||||
err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create embedding model file: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model Model
|
||||
checkCaps []model.Capability
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "completion model without tools capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools},
|
||||
expectedErrMsg: "does not support tools",
|
||||
},
|
||||
{
|
||||
name: "model with all needed capabilities",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model missing insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityInsert},
|
||||
expectedErrMsg: "does not support insert",
|
||||
},
|
||||
{
|
||||
name: "model missing vision capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
expectedErrMsg: "does not support vision",
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
model: Model{
|
||||
ModelPath: visionModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
},
|
||||
{
|
||||
name: "model with embedding capability",
|
||||
model: Model{
|
||||
ModelPath: embeddingModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityEmbedding},
|
||||
},
|
||||
{
|
||||
name: "unknown capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
expectedErrMsg: "unknown capability",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test CheckCapabilities method
|
||||
err := tt.model.CheckCapabilities(tt.checkCaps...)
|
||||
if tt.expectedErrMsg == "" {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user