Mirror of https://github.com/ollama/ollama.git (synced 2026-02-05 13:13:34 -05:00)

Compare commits: v0.15.0...parth-laun (41 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 7b24289871 | |
| | 3453020b23 | |
| | 2e1e05db5d | |
| | aca0489121 | |
| | 80157f7512 | |
| | 0b07dfdc71 | |
| | 545384aa16 | |
| | a6355329bf | |
| | 0398b24b42 | |
| | 75b1dddf91 | |
| | e1e80ffc3e | |
| | 71896485fd | |
| | ef00199fb4 | |
| | 8f4a008139 | |
| | d8cc798c2b | |
| | 6582f6da5c | |
| | 0334ffa625 | |
| | d11fbd2c60 | |
| | 6a7c3f188e | |
| | 427e2c962a | |
| | 27db7f806f | |
| | 3590fbfa76 | |
| | cd0094f772 | |
| | 06bc8e6712 | |
| | fc5f9bb448 | |
| | a0740f7ef7 | |
| | a0923cbdd0 | |
| | f92e362b2e | |
| | aa23d8ecd2 | |
| | 7b62c41060 | |
| | 26acab64b7 | |
| | e0f03790b1 | |
| | 3ab842b0f5 | |
| | b8e8ef8929 | |
| | 465d124183 | |
| | d310e56fa3 | |
| | a1ca428c90 | |
| | 16750865d1 | |
| | f3b476c592 | |
| | 5267d31d56 | |
| | b44f56319f | |
2 .github/workflows/release.yaml (vendored)
@@ -92,7 +92,7 @@ jobs:
        flags: ''
      - os: windows
        arch: amd64
        preset: 'CUDA 13'
        preset: 'CUDA 13 Windows'
        install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
        cuda-components:
          - '"cudart"'
@@ -40,7 +40,17 @@
      "name": "CUDA 13",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
        "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
        "CMAKE_CUDA_FLAGS": "-t 4",
        "OLLAMA_RUNNER_DIR": "cuda_v13"
      }
    },
    {
      "name": "CUDA 13 Windows",
      "inherits": [ "CUDA" ],
      "description": "Reduced architecture set for Windows to avoid MSVC template compilation issues",
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "75-virtual;89-virtual;100-virtual;120-virtual",
        "CMAKE_CUDA_FLAGS": "-t 4",
        "OLLAMA_RUNNER_DIR": "cuda_v13"
      }
@@ -138,6 +148,11 @@
      "inherits": [ "CUDA" ],
      "configurePreset": "CUDA 13"
    },
    {
      "name": "CUDA 13 Windows",
      "inherits": [ "CUDA" ],
      "configurePreset": "CUDA 13 Windows"
    },
    {
      "name": "JetPack 5",
      "inherits": [ "CUDA" ],
@@ -169,8 +169,10 @@ COPY . .
RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ARG CGO_CFLAGS
ARG CGO_CXXFLAGS
ENV CGO_CFLAGS="${CGO_CFLAGS} -I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src"
ENV CGO_CXXFLAGS="${CGO_CXXFLAGS}"
RUN --mount=type=cache,target=/root/.cache/go-build \
    go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama .
@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=ec98e2002
FETCH_HEAD=a5bb8ba4c50257437630c136210396810741bbf7

.PHONY: help
help:
@@ -358,6 +358,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Screenpipe](https://github.com/mediar-ai/screenpipe) (24/7 screen & mic recording with AI-powered search, uses Ollama for local LLM features)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
@@ -465,6 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
- [Stakpak](https://github.com/stakpak/agent) (An open source, vendor neutral DevOps agent that works with any model, and any stack, for teams who just want to ship)

### Cloud

@@ -558,7 +560,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaFarm for Go](https://github.com/presbrey/ollamafarm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama for Ruby](https://github.com/crmne/ruby_llm)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/ollama4j/ollama4j)
157 anthropic/anthropic.go (Normal file → Executable file)
@@ -211,6 +211,7 @@ type MessageDelta struct {
|
||||
|
||||
// DeltaUsage contains cumulative token usage
|
||||
type DeltaUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
@@ -517,24 +518,26 @@ func mapStopReason(reason string, hasToolCalls bool) string {
|
||||
|
||||
// StreamConverter manages state for converting Ollama streaming responses to Anthropic format
|
||||
type StreamConverter struct {
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
ID string
|
||||
Model string
|
||||
firstWrite bool
|
||||
contentIndex int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0)
|
||||
thinkingStarted bool
|
||||
thinkingDone bool
|
||||
textStarted bool
|
||||
toolCallsSent map[string]bool
|
||||
}
|
||||
|
||||
func NewStreamConverter(id, model string) *StreamConverter {
|
||||
func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter {
|
||||
return &StreamConverter{
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
ID: id,
|
||||
Model: model,
|
||||
firstWrite: true,
|
||||
estimatedInputTokens: estimatedInputTokens,
|
||||
toolCallsSent: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -550,7 +553,11 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
|
||||
if c.firstWrite {
|
||||
c.firstWrite = false
|
||||
// Use actual metrics if available, otherwise use estimate
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
if c.inputTokens == 0 && c.estimatedInputTokens > 0 {
|
||||
c.inputTokens = c.estimatedInputTokens
|
||||
}
|
||||
|
||||
events = append(events, StreamEvent{
|
||||
Event: "message_start",
|
||||
@@ -721,6 +728,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
})
|
||||
}
|
||||
|
||||
c.inputTokens = r.Metrics.PromptEvalCount
|
||||
c.outputTokens = r.Metrics.EvalCount
|
||||
stopReason := mapStopReason(r.DoneReason, len(c.toolCallsSent) > 0)
|
||||
|
||||
@@ -732,6 +740,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent {
|
||||
StopReason: stopReason,
|
||||
},
|
||||
Usage: DeltaUsage{
|
||||
InputTokens: c.inputTokens,
|
||||
OutputTokens: c.outputTokens,
|
||||
},
|
||||
},
|
||||
@@ -776,3 +785,121 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
	}
	return args
}

// CountTokensRequest represents an Anthropic count_tokens request
type CountTokensRequest struct {
	Model string `json:"model"`
	Messages []MessageParam `json:"messages"`
	System any `json:"system,omitempty"`
	Tools []Tool `json:"tools,omitempty"`
	Thinking *ThinkingConfig `json:"thinking,omitempty"`
}

// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic)
func EstimateInputTokens(req MessagesRequest) int {
	return estimateTokens(CountTokensRequest{
		Model: req.Model,
		Messages: req.Messages,
		System: req.System,
		Tools: req.Tools,
		Thinking: req.Thinking,
	})
}

// CountTokensResponse represents an Anthropic count_tokens response
type CountTokensResponse struct {
	InputTokens int `json:"input_tokens"`
}

// estimateTokens returns a rough estimate of tokens (len/4)
func estimateTokens(req CountTokensRequest) int {
	var totalLen int

	// Count system prompt
	if req.System != nil {
		totalLen += countAnyContent(req.System)
	}

	// Count messages
	for _, msg := range req.Messages {
		// Count role (always present)
		totalLen += len(msg.Role)
		// Count content
		contentLen := countAnyContent(msg.Content)
		totalLen += contentLen
	}

	for _, tool := range req.Tools {
		totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema)
	}

	// Return len/4 as rough token estimate, minimum 1 if there's any content
	tokens := totalLen / 4
	if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) {
		tokens = 1
	}
	return tokens
}

func countAnyContent(content any) int {
	if content == nil {
		return 0
	}

	switch c := content.(type) {
	case string:
		return len(c)
	case []any:
		total := 0
		for _, block := range c {
			total += countContentBlock(block)
		}
		return total
	default:
		if data, err := json.Marshal(content); err == nil {
			return len(data)
		}
		return 0
	}
}

func countContentBlock(block any) int {
	blockMap, ok := block.(map[string]any)
	if !ok {
		if s, ok := block.(string); ok {
			return len(s)
		}
		return 0
	}

	total := 0
	blockType, _ := blockMap["type"].(string)

	if text, ok := blockMap["text"].(string); ok {
		total += len(text)
	}

	if thinking, ok := blockMap["thinking"].(string); ok {
		total += len(thinking)
	}

	if blockType == "tool_use" {
		if data, err := json.Marshal(blockMap); err == nil {
			total += len(data)
		}
	}

	if blockType == "tool_result" {
		if data, err := json.Marshal(blockMap); err == nil {
			total += len(data)
		}
	}

	if source, ok := blockMap["source"].(map[string]any); ok {
		if data, ok := source["data"].(string); ok {
			total += len(data)
		}
	}

	return total
}
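The helpers above give the Anthropic-compatible layer a cheap fallback for input-token accounting. The sketch below is illustrative only (it is not part of this change): it shows how the estimate is intended to feed the stream converter when the runner has not yet reported `PromptEvalCount`. The message ID, model name, and system prompt are placeholders.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/anthropic"
)

func main() {
	// Hypothetical request; only Model and System are set for brevity.
	req := anthropic.MessagesRequest{
		Model:  "qwen3",
		System: "You are a helpful assistant.",
	}

	// Rough len/4 estimate over system prompt, messages, and tool definitions.
	estimated := anthropic.EstimateInputTokens(req)

	// The estimate is only used when the first chunk reports PromptEvalCount == 0.
	conv := anthropic.NewStreamConverter("msg_example", req.Model, estimated)
	fmt.Println("estimated input tokens:", estimated, "for", conv.Model)
}
```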
32 anthropic/anthropic_test.go (Normal file → Executable file)
@@ -605,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_Basic(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// First chunk
|
||||
resp1 := api.ChatResponse{
|
||||
@@ -642,7 +642,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
},
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{EvalCount: 5},
|
||||
Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
|
||||
}
|
||||
|
||||
events2 := conv.Process(resp2)
|
||||
@@ -650,6 +650,24 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
// Should have content_block_delta, content_block_stop, message_delta, message_stop
|
||||
hasStop := false
|
||||
for _, e := range events2 {
|
||||
if e.Event == "message_delta" {
|
||||
if data, ok := e.Data.(MessageDeltaEvent); ok {
|
||||
if data.Type != "message_delta" {
|
||||
t.Errorf("unexpected data type: %+v", data)
|
||||
}
|
||||
|
||||
if data.Delta.StopReason != "end_turn" {
|
||||
t.Errorf("unexpected stop reason: %+v", data.Delta.StopReason)
|
||||
}
|
||||
|
||||
if data.Usage.InputTokens != 10 || data.Usage.OutputTokens != 5 {
|
||||
t.Errorf("unexpected usage: %+v", data.Usage)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("unexpected data: %+v", e.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if e.Event == "message_stop" {
|
||||
hasStop = true
|
||||
}
|
||||
@@ -660,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -713,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
|
||||
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
// Test that unmarshalable arguments (like channels) are handled gracefully
|
||||
// and don't cause a panic or corrupt stream
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
// Create a channel which cannot be JSON marshaled
|
||||
unmarshalable := make(chan int)
|
||||
@@ -760,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
|
||||
|
||||
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
|
||||
// Test that valid tool calls still work when mixed with invalid ones
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
unmarshalable := make(chan int)
|
||||
badArgs := api.NewToolCallFunctionArguments()
|
||||
@@ -885,7 +903,7 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) {
|
||||
// events include the required empty fields for SDK compatibility.
|
||||
func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
t.Run("text block start includes empty text", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
@@ -919,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("thinking block start includes empty thinking", func(t *testing.T) {
|
||||
conv := NewStreamConverter("msg_123", "test-model")
|
||||
conv := NewStreamConverter("msg_123", "test-model", 0)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: "test-model",
|
||||
|
||||
@@ -466,3 +466,25 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
	}
	return &resp, nil
}

// AliasRequest is the request body for creating or updating a model alias.
type AliasRequest struct {
	Alias string `json:"alias"`
	Target string `json:"target"`
	PrefixMatching bool `json:"prefix_matching,omitempty"`
}

// SetAliasExperimental creates or updates a model alias via the experimental aliases API.
func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error {
	return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil)
}

// AliasDeleteRequest is the request body for deleting a model alias.
type AliasDeleteRequest struct {
	Alias string `json:"alias"`
}

// DeleteAliasExperimental deletes a model alias via the experimental aliases API.
func (c *Client) DeleteAliasExperimental(ctx context.Context, req *AliasDeleteRequest) error {
	return c.do(ctx, http.MethodDelete, "/api/experimental/aliases", req, nil)
}
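For context, a hedged sketch of how a caller might exercise these experimental client methods: register a prefix alias so any model name starting with a given prefix resolves to a configured target, then remove it. The target model name here is a placeholder.

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// Prefix matching means names like "claude-sonnet-4-5" all route to the target.
	if err := client.SetAliasExperimental(ctx, &api.AliasRequest{
		Alias:          "claude-sonnet-",
		Target:         "qwen3-coder", // placeholder target model
		PrefixMatching: true,
	}); err != nil {
		log.Fatal(err)
	}

	// Later, delete the alias again.
	if err := client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: "claude-sonnet-"}); err != nil {
		log.Fatal(err)
	}
}
```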
@@ -75,9 +75,9 @@ The `-dev` flag enables:
CI builds with Xcode 14.1 for OS compatibility prior to v13. If you want to manually build v11+ support, you can download the older Xcode [here](https://developer.apple.com/services-account/download?path=/Developer_Tools/Xcode_14.1/Xcode_14.1.xip), extract, then `mv ./Xcode.app /Applications/Xcode_14.1.0.app` then activate with:

```
export CGO_CFLAGS=-mmacosx-version-min=12.0
export CGO_CXXFLAGS=-mmacosx-version-min=12.0
export CGO_LDFLAGS=-mmacosx-version-min=12.0
export CGO_CFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_CXXFLAGS="-O3 -mmacosx-version-min=12.0"
export CGO_LDFLAGS="-mmacosx-version-min=12.0"
export SDKROOT=/Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
export DEVELOPER_DIR=/Applications/Xcode_14.1.0.app/Contents/Developer
```
@@ -29,6 +29,7 @@ import (
	"github.com/containerd/console"
	"github.com/mattn/go-runewidth"
	"github.com/olekukonko/tablewriter"
	"github.com/pkg/browser"
	"github.com/spf13/cobra"
	"golang.org/x/crypto/ssh"
	"golang.org/x/sync/errgroup"
@@ -52,7 +53,7 @@ import (
	"github.com/ollama/ollama/x/imagegen"
)

const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"

// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -663,6 +664,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
	fmt.Println()

	if aErr.SigninURL != "" {
		_ = browser.OpenURL(aErr.SigninURL)
		fmt.Printf(ConnectInstructions, aErr.SigninURL)
	}
	return nil
@@ -1750,7 +1752,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
			return err
		}
		if err := startApp(cmd.Context(), client); err != nil {
			return fmt.Errorf("ollama server not responding - %w", err)
			return err
		}
	}
	return nil
@@ -1888,7 +1890,7 @@ func NewCLI() *cobra.Command {
	serveCmd := &cobra.Command{
		Use:     "serve",
		Aliases: []string{"start"},
		Short:   "Start ollama",
		Short:   "Start Ollama",
		Args:    cobra.ExactArgs(0),
		RunE:    RunServer,
	}
@@ -1553,7 +1553,7 @@ func TestShowInfoImageGen(t *testing.T) {
		Details: api.ModelDetails{
			Family: "ZImagePipeline",
			ParameterSize: "10.3B",
			QuantizationLevel: "FP8",
			QuantizationLevel: "Q8",
		},
		Capabilities: []model.Capability{model.CapabilityImage},
		Requires: "0.14.0",
@@ -1565,7 +1565,7 @@ func TestShowInfoImageGen(t *testing.T) {
	expect := " Model\n" +
		" architecture ZImagePipeline \n" +
		" parameters 10.3B \n" +
		" quantization FP8 \n" +
		" quantization Q8 \n" +
		" requires 0.14.0 \n" +
		"\n" +
		" Capabilities\n" +
@@ -1,36 +1,166 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Claude implements Runner for Claude Code integration
|
||||
// Claude implements Runner and AliasConfigurer for Claude Code integration
|
||||
type Claude struct{}
|
||||
|
||||
// Compile-time check that Claude implements AliasConfigurer
|
||||
var _ AliasConfigurer = (*Claude)(nil)
|
||||
|
||||
func (c *Claude) String() string { return "Claude Code" }
|
||||
|
||||
func (c *Claude) args(model string) []string {
|
||||
func (c *Claude) args(model string, extra []string) []string {
|
||||
var args []string
|
||||
if model != "" {
|
||||
return []string{"--model", model}
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
return nil
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string) error {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
func (c *Claude) findPath() (string, error) {
|
||||
if p, err := exec.LookPath("claude"); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(home, ".claude", "local", name)
|
||||
if _, err := os.Stat(fallback); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
func (c *Claude) Run(model string, args []string) error {
|
||||
claudePath, err := c.findPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
|
||||
}
|
||||
|
||||
cmd := exec.Command("claude", c.args(model)...)
|
||||
cmd := exec.Command(claudePath, c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"ANTHROPIC_BASE_URL=http://localhost:11434",
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ConfigureAliases sets up model aliases for Claude Code.
|
||||
// Cloud-only: subagent routing (fast model) is gated to cloud models only until
|
||||
// there is a better strategy for prompt caching on local models.
|
||||
func (c *Claude) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) {
|
||||
aliases := make(map[string]string)
|
||||
for k, v := range existing {
|
||||
aliases[k] = v
|
||||
}
|
||||
|
||||
if primaryModel != "" {
|
||||
aliases["primary"] = primaryModel
|
||||
}
|
||||
|
||||
if !force && aliases["primary"] != "" {
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if isCloudModel(ctx, client, aliases["fast"]) {
|
||||
return aliases, false, nil
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
return aliases, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
items, existingModels, cloudModels, client, err := listModels(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||
|
||||
if aliases["primary"] == "" || force {
|
||||
primary, err := selectPrompt("Select model:", items)
|
||||
fmt.Fprintf(os.Stderr, "\033[3A\033[J")
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
aliases["primary"] = primary
|
||||
}
|
||||
|
||||
if isCloudModel(ctx, client, aliases["primary"]) {
|
||||
if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
} else {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
return aliases, true, nil
|
||||
}
|
||||
|
||||
// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
|
||||
// Cloud-only: for local models (fast is empty), we delete any existing aliases to
|
||||
// prevent stale routing to a previous cloud model.
|
||||
func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefixes := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
if aliases["fast"] == "" {
|
||||
for _, prefix := range prefixes {
|
||||
_ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
prefixAliases := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
var errs []string
|
||||
for prefix, target := range prefixAliases {
|
||||
req := &api.AliasRequest{
|
||||
Alias: prefix,
|
||||
Target: target,
|
||||
PrefixMatching: true,
|
||||
}
|
||||
if err := client.SetAliasExperimental(ctx, req); err != nil {
|
||||
errs = append(errs, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to set aliases: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
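Taken together, `ConfigureAliases` and `SetAliases` are the two-step flow an integration command would drive: resolve the alias map for the chosen primary model, then push the Claude Code prefix aliases to the server. A minimal sketch follows, assuming the package is importable as `github.com/ollama/ollama/cmd/config`, a running Ollama server, and using a cloud model name borrowed from the tests below.

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/cmd/config"
)

func main() {
	ctx := context.Background()
	c := &config.Claude{}

	// Resolve the alias map: for a cloud primary, "fast" defaults to the primary;
	// for a local primary, "fast" is dropped entirely.
	aliases, changed, err := c.ConfigureAliases(ctx, "kimi-k2.5:cloud", nil, false)
	if err != nil {
		log.Fatal(err)
	}

	// Sync to the server: claude-sonnet-* routes to "primary" and claude-haiku-*
	// to "fast"; with no "fast" alias, both prefixes are deleted instead.
	if changed {
		if err := c.SetAliases(ctx, aliases); err != nil {
			log.Fatal(err)
		}
	}
}
```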
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
@@ -19,23 +22,83 @@ func TestClaudeIntegration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeFindPath(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
t.Run("finds claude in PATH", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fakeBin := filepath.Join(tmpDir, name)
|
||||
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fakeBin {
|
||||
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
name := "claude"
|
||||
if runtime.GOOS == "windows" {
|
||||
name = "claude.exe"
|
||||
}
|
||||
fallback := filepath.Join(tmpDir, ".claude", "local", name)
|
||||
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||
|
||||
got, err := c.findPath()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != fallback {
|
||||
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
|
||||
|
||||
_, err := c.findPath()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeArgs(t *testing.T) {
|
||||
c := &Claude{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil},
|
||||
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||
{"empty model", "", nil, nil},
|
||||
{"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||
{"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,20 +14,21 @@ type Codex struct{}
|
||||
|
||||
func (c *Codex) String() string { return "Codex" }
|
||||
|
||||
func (c *Codex) args(model string) []string {
|
||||
func (c *Codex) args(model string, extra []string) []string {
|
||||
args := []string{"--oss"}
|
||||
if model != "" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (c *Codex) Run(model string) error {
|
||||
func (c *Codex) Run(model string, args []string) error {
|
||||
if err := checkCodexVersion(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command("codex", c.args(model)...)
|
||||
cmd := exec.Command("codex", c.args(model, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
@@ -11,17 +11,20 @@ func TestCodexArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
args []string
|
||||
want []string
|
||||
}{
|
||||
{"with model", "llama3.2", []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", []string{"--oss"}},
|
||||
{"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
|
||||
{"empty model", "", nil, []string{"--oss"}},
|
||||
{"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
|
||||
{"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := c.args(tt.model)
|
||||
got := c.args(tt.model, tt.args)
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("args(%q) = %v, want %v", tt.model, got, tt.want)
|
||||
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,13 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type integration struct {
|
||||
Models []string `json:"models"`
|
||||
Models []string `json:"models"`
|
||||
Aliases map[string]string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
type config struct {
|
||||
@@ -20,6 +22,14 @@ type config struct {
|
||||
}
|
||||
|
||||
func configPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "config.json"), nil
|
||||
}
|
||||
|
||||
func legacyConfigPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,6 +37,46 @@ func configPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", "config", "config.json"), nil
|
||||
}
|
||||
|
||||
// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
|
||||
func migrateConfig() (bool, error) {
|
||||
oldPath, err := legacyConfigPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldData, err := os.ReadFile(oldPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
var js json.RawMessage
|
||||
if err := json.Unmarshal(oldData, &js); err != nil {
|
||||
slog.Warn("legacy config has invalid JSON, skipping migration", "path", oldPath, "error", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newPath, err := configPath()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
|
||||
return false, fmt.Errorf("write new config: %w", err)
|
||||
}
|
||||
|
||||
_ = os.Remove(oldPath)
|
||||
_ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
|
||||
|
||||
slog.Info("migrated config", "from", oldPath, "to", newPath)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func load() (*config, error) {
|
||||
path, err := configPath()
|
||||
if err != nil {
|
||||
@@ -34,6 +84,11 @@ func load() (*config, error) {
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
if migrated, merr := migrateConfig(); merr == nil && migrated {
|
||||
data, err = os.ReadFile(path)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &config{Integrations: make(map[string]*integration)}, nil
|
||||
@@ -79,8 +134,16 @@ func saveIntegration(appName string, models []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Integrations[strings.ToLower(appName)] = &integration{
|
||||
Models: models,
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
var aliases map[string]string
|
||||
if existing != nil && existing.Aliases != nil {
|
||||
aliases = existing.Aliases
|
||||
}
|
||||
|
||||
cfg.Integrations[key] = &integration{
|
||||
Models: models,
|
||||
Aliases: aliases,
|
||||
}
|
||||
|
||||
return save(cfg)
|
||||
@@ -100,6 +163,29 @@ func loadIntegration(appName string) (*integration, error) {
|
||||
return ic, nil
|
||||
}
|
||||
|
||||
func saveAliases(appName string, aliases map[string]string) error {
|
||||
if appName == "" {
|
||||
return errors.New("app name cannot be empty")
|
||||
}
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := strings.ToLower(appName)
|
||||
existing := cfg.Integrations[key]
|
||||
if existing == nil {
|
||||
existing = &integration{}
|
||||
}
|
||||
|
||||
// Replace aliases entirely (not merge) so deletions are persisted
|
||||
existing.Aliases = aliases
|
||||
|
||||
cfg.Integrations[key] = existing
|
||||
return save(cfg)
|
||||
}
|
||||
|
||||
func listIntegrations() ([]integration, error) {
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
|
||||
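Since `load` transparently migrates the legacy `~/.ollama/config/config.json` into `~/.ollama/config.json`, an in-package test sketch of that behaviour could look like the following. This is an assumption-laden sketch, not code from this change: it relies on `setTestHome` pointing `os.UserHomeDir` at the temp directory, as the existing tests do.

```go
package config

import (
	"os"
	"path/filepath"
	"testing"
)

func TestMigrateLegacyConfig_Sketch(t *testing.T) {
	home := t.TempDir()
	setTestHome(t, home)

	// Seed a config at the legacy location ~/.ollama/config/config.json.
	legacy := filepath.Join(home, ".ollama", "config", "config.json")
	if err := os.MkdirAll(filepath.Dir(legacy), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(legacy, []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644); err != nil {
		t.Fatal(err)
	}

	// The first load migrates the file to ~/.ollama/config.json and reads it.
	cfg, err := load()
	if err != nil {
		t.Fatal(err)
	}
	if cfg.Integrations["claude"] == nil {
		t.Fatal("expected migrated claude integration")
	}
	if _, err := os.Stat(filepath.Join(home, ".ollama", "config.json")); err != nil {
		t.Fatalf("expected config at new path: %v", err)
	}
	if _, err := os.Stat(legacy); !os.IsNotExist(err) {
		t.Fatalf("expected legacy config to be removed, got %v", err)
	}
}
```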
735 cmd/config/config_cloud_test.go (Normal file)
@@ -0,0 +1,735 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock API client for testing
|
||||
type mockAliasClient struct {
|
||||
setAliasCalls []setAliasCall
|
||||
deleteAliasCalls []string
|
||||
setAliasErr error
|
||||
deleteAliasErr error
|
||||
}
|
||||
|
||||
type setAliasCall struct {
|
||||
alias string
|
||||
target string
|
||||
prefixMatching bool
|
||||
}
|
||||
|
||||
// TestSetAliases_CloudModel verifies that cloud models set both prefix aliases
|
||||
func TestSetAliases_CloudModel(t *testing.T) {
|
||||
// Test the SetAliases logic by checking the alias map behavior
|
||||
aliases := map[string]string{
|
||||
"primary": "kimi-k2.5:cloud",
|
||||
"fast": "kimi-k2.5:cloud",
|
||||
}
|
||||
|
||||
// Verify fast is set (cloud model behavior)
|
||||
if aliases["fast"] == "" {
|
||||
t.Error("cloud model should have fast alias set")
|
||||
}
|
||||
if aliases["fast"] != aliases["primary"] {
|
||||
t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetAliases_LocalModel verifies that local models clear fast and would delete aliases
|
||||
func TestSetAliases_LocalModel(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:latest",
|
||||
}
|
||||
// Simulate local model behavior: fast should be empty
|
||||
delete(aliases, "fast")
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("local model should have empty fast alias")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_ReplacesNotMerges verifies the critical fix where saveAliases
|
||||
// replaces aliases entirely instead of merging (which would leave stale keys)
|
||||
func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save with both primary and fast
|
||||
initial := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
if err := saveAliases("claude", initial); err != nil {
|
||||
t.Fatalf("failed to save initial aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify both are saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
|
||||
// Now save without fast (simulating switch to local model)
|
||||
updated := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast intentionally missing
|
||||
}
|
||||
if err := saveAliases("claude", updated); err != nil {
|
||||
t.Fatalf("failed to save updated aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify fast is GONE (not merged/preserved)
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_PreservesModels verifies that updating aliases doesn't affect models
|
||||
func TestSaveAliases_PreservesModels(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// First save integration with models
|
||||
if err := saveIntegration("claude", []string{"model1", "model2"}); err != nil {
|
||||
t.Fatalf("failed to save integration: %v", err)
|
||||
}
|
||||
|
||||
// Then update aliases
|
||||
aliases := map[string]string{"primary": "new-model"}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save aliases: %v", err)
|
||||
}
|
||||
|
||||
// Verify models are preserved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
|
||||
t.Errorf("models should be preserved, got %v", loaded.Models)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyMap clears all aliases
|
||||
func TestSaveAliases_EmptyMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save empty map
|
||||
if err := saveAliases("claude", map[string]string{}); err != nil {
|
||||
t.Fatalf("failed to save empty: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if len(loaded.Aliases) != 0 {
|
||||
t.Errorf("aliases should be empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_NilMap handles nil gracefully
|
||||
func TestSaveAliases_NilMap(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save with aliases first
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Save nil map - should clear aliases
|
||||
if err := saveAliases("claude", nil); err != nil {
|
||||
t.Fatalf("failed to save nil: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases != nil && len(loaded.Aliases) > 0 {
|
||||
t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_EmptyAppName returns error
|
||||
func TestSaveAliases_EmptyAppName(t *testing.T) {
|
||||
err := saveAliases("", map[string]string{"primary": "model"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty app name")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CaseInsensitive verifies app names are case insensitive
|
||||
func TestSaveAliases_CaseInsensitive(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Load with different case
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model1" {
|
||||
t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
|
||||
// Update with different case
|
||||
if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to update: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model2" {
|
||||
t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
|
||||
func TestSaveAliases_CreatesIntegration(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases for non-existent integration
|
||||
if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("newintegration")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanAliasesForLocalModel tests the cleanup function
|
||||
// TestConfigureAliases_AliasMap tests the alias map manipulation logic
|
||||
func TestConfigureAliases_AliasMap(t *testing.T) {
|
||||
t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
|
||||
aliases := make(map[string]string)
|
||||
aliases["primary"] = "cloud-model"
|
||||
|
||||
// Simulate cloud model behavior
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud model preserves custom fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "custom-fast-model",
|
||||
}
|
||||
|
||||
// Simulate cloud model behavior - should preserve existing fast
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "custom-fast-model" {
|
||||
t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local model clears fast", func(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
"fast": "should-be-cleared",
|
||||
}
|
||||
|
||||
// Simulate local model behavior
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching cloud to local clears fast", func(t *testing.T) {
|
||||
// Start with cloud config
|
||||
aliases := map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}
|
||||
|
||||
// Switch to local
|
||||
aliases["primary"] = "local-model"
|
||||
isCloud := false
|
||||
if !isCloud {
|
||||
delete(aliases, "fast")
|
||||
}
|
||||
|
||||
if aliases["fast"] != "" {
|
||||
t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
|
||||
}
|
||||
if aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be updated, got %q", aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("switching local to cloud sets fast", func(t *testing.T) {
|
||||
// Start with local config (no fast)
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
}
|
||||
|
||||
// Switch to cloud
|
||||
aliases["primary"] = "cloud-model"
|
||||
isCloud := true
|
||||
if isCloud {
|
||||
if aliases["fast"] == "" {
|
||||
aliases["fast"] = aliases["primary"]
|
||||
}
|
||||
}
|
||||
|
||||
if aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSetAliases_PrefixMapping verifies the correct prefix-to-alias mapping
|
||||
func TestSetAliases_PrefixMapping(t *testing.T) {
|
||||
// This tests the expected mapping without needing a real client
|
||||
aliases := map[string]string{
|
||||
"primary": "my-cloud-model",
|
||||
"fast": "my-fast-model",
|
||||
}
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"claude-sonnet-": aliases["primary"],
|
||||
"claude-haiku-": aliases["fast"],
|
||||
}
|
||||
|
||||
if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
|
||||
t.Errorf("claude-sonnet- should map to primary")
|
||||
}
|
||||
if expectedMappings["claude-haiku-"] != "my-fast-model" {
|
||||
t.Errorf("claude-haiku- should map to fast")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetAliases_LocalDeletesPrefixes verifies local models should delete both prefixes
|
||||
func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
|
||||
aliases := map[string]string{
|
||||
"primary": "local-model",
|
||||
// fast is empty/missing - indicates local model
|
||||
}
|
||||
|
||||
prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
|
||||
|
||||
// Verify the logic: when fast is empty, we should delete
|
||||
if aliases["fast"] != "" {
|
||||
t.Error("fast should be empty for local model")
|
||||
}
|
||||
|
||||
// Verify we have the right prefixes to delete
|
||||
if len(prefixesToDelete) != 2 {
|
||||
t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
|
||||
func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server fails, config should NOT be saved
|
||||
serverErr := errors.New("server unavailable")
|
||||
configSaved := false
|
||||
|
||||
if serverErr != nil {
|
||||
// Server failed - don't save config
|
||||
configSaved = false
|
||||
} else {
|
||||
configSaved = true
|
||||
}
|
||||
|
||||
if configSaved {
|
||||
t.Error("config should NOT be saved when server fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
|
||||
func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate: server succeeds, config should be saved
|
||||
var serverErr error = nil
|
||||
configSaved := false
|
||||
|
||||
if serverErr != nil {
|
||||
configSaved = false
|
||||
} else {
|
||||
// Server succeeded - save config
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
|
||||
t.Fatalf("saveAliases failed: %v", err)
|
||||
}
|
||||
configSaved = true
|
||||
}
|
||||
|
||||
if !configSaved {
|
||||
t.Error("config should be saved when server succeeds")
|
||||
}
|
||||
|
||||
// Verify it was actually saved
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != "model" {
|
||||
t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigFile_PreservesUnknownFields verifies we don't lose unknown fields
|
||||
func TestConfigFile_PreservesUnknownFields(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Write config with extra fields
|
||||
configPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
os.MkdirAll(filepath.Dir(configPath), 0o755)
|
||||
|
||||
// Note: Our config struct only has Integrations, so top-level unknown fields
|
||||
// won't be preserved by our current implementation. This test documents that.
|
||||
initialConfig := `{
|
||||
"integrations": {
|
||||
"claude": {
|
||||
"models": ["model1"],
|
||||
"aliases": {"primary": "model1"},
|
||||
"unknownField": "should be lost"
|
||||
}
|
||||
},
|
||||
"topLevelUnknown": "will be lost"
|
||||
}`
|
||||
os.WriteFile(configPath, []byte(initialConfig), 0o644)
|
||||
|
||||
// Update aliases
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
|
||||
t.Fatalf("failed to save: %v", err)
|
||||
}
|
||||
|
||||
// Read raw file to check
|
||||
data, _ := os.ReadFile(configPath)
|
||||
content := string(data)
|
||||
|
||||
// models should be preserved
|
||||
if !contains(content, "model1") {
|
||||
t.Error("models should be preserved")
|
||||
}
|
||||
|
||||
// primary should be updated
|
||||
if !contains(content, "model2") {
|
||||
t.Error("primary should be updated to model2")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestClaudeImplementsAliasConfigurer verifies Claude implements the interface
|
||||
func TestClaudeImplementsAliasConfigurer(t *testing.T) {
|
||||
c := &Claude{}
|
||||
var _ AliasConfigurer = c // Compile-time check
|
||||
}
|
||||
|
||||
// TestModelNameEdgeCases tests various model name formats
|
||||
func TestModelNameEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model string
|
||||
}{
|
||||
{"simple", "llama3.2"},
|
||||
{"with tag", "llama3.2:latest"},
|
||||
{"with cloud tag", "kimi-k2.5:cloud"},
|
||||
{"with namespace", "library/llama3.2"},
|
||||
{"with dots", "glm-4.7-flash"},
|
||||
{"with numbers", "qwen3:8b"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
aliases := map[string]string{"primary": tc.model}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatalf("failed to save model %q: %v", tc.model, err)
|
||||
}
|
||||
|
||||
loaded, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load: %v", err)
|
||||
}
|
||||
if loaded.Aliases["primary"] != tc.model {
|
||||
t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSwitchingScenarios tests the full switching behavior
|
||||
func TestSwitchingScenarios(t *testing.T) {
|
||||
t.Run("cloud to local removes fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to local (no fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "" {
|
||||
t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "local-model" {
|
||||
t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local to cloud adds fast", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial local config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "local-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to cloud (with fast)
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model",
|
||||
"fast": "cloud-model",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["fast"] != "cloud-model" {
|
||||
t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cloud to different cloud updates both", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial cloud config
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-1",
|
||||
"fast": "cloud-model-1",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Switch to different cloud
|
||||
if err := saveAliases("claude", map[string]string{
|
||||
"primary": "cloud-model-2",
|
||||
"fast": "cloud-model-2",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != "cloud-model-2" {
|
||||
t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
if loaded.Aliases["fast"] != "cloud-model-2" {
|
||||
t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockRunner is a test runner that doesn't call the real server
|
||||
type mockRunner struct{}
|
||||
|
||||
func (m *mockRunner) Run(model string, args []string) error { return nil }
|
||||
func (m *mockRunner) String() string { return "mock" }
|
||||
|
||||
// TestToolCapabilityFiltering tests the tool capability detection logic
|
||||
func TestToolCapabilityFiltering(t *testing.T) {
|
||||
t.Run("all models checked for tool capability", func(t *testing.T) {
|
||||
// Both cloud and local models are checked for tool capability via Show API
|
||||
// Only models with "tools" in capabilities are included
|
||||
m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("tool capable model should be marked as such")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
|
||||
m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
|
||||
if !m.ToolCapable {
|
||||
t.Error("ToolCapable field should be accessible")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsCloudModel_SuffixLogic tests the cloud model suffix detection
|
||||
func TestIsCloudModel_SuffixLogic(t *testing.T) {
|
||||
t.Run("model with :cloud suffix is cloud", func(t *testing.T) {
|
||||
// nil client tests suffix-only path
|
||||
if !isCloudModel(context.Background(), nil, "model:cloud") {
|
||||
t.Error("model with :cloud suffix should be detected as cloud")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model without :cloud suffix is not cloud (suffix check)", func(t *testing.T) {
|
||||
// nil client means only suffix is checked
|
||||
if isCloudModel(context.Background(), nil, "local-model") {
|
||||
t.Error("model without :cloud suffix should not be detected as cloud")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestModelsAndAliasesMustStayInSync verifies that models[0] and aliases["primary"]
|
||||
// stay synchronized. This prevents the bug where they can diverge.
|
||||
func TestModelsAndAliasesMustStayInSync(t *testing.T) {
|
||||
t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Save aliases with one model
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Save integration with same model (this is the pattern we use)
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Aliases["primary"] != loaded.Models[0] {
|
||||
t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("out of sync config is detectable", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Simulate out-of-sync state (like manual edit or bug)
|
||||
if err := saveIntegration("claude", []string{"old-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
|
||||
// They should be different (this is the bug state)
|
||||
if loaded.Models[0] == loaded.Aliases["primary"] {
|
||||
t.Error("expected out-of-sync state for this test")
|
||||
}
|
||||
|
||||
// The fix: when updating aliases, also update models
|
||||
if err := saveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ = loadIntegration("claude")
|
||||
if loaded.Models[0] != loaded.Aliases["primary"] {
|
||||
t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
|
||||
loaded.Models[0], loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updating primary alias updates models too", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Initial state
|
||||
if err := saveIntegration("claude", []string{"initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Update aliases AND models together
|
||||
newAliases := map[string]string{"primary": "updated-model"}
|
||||
if err := saveAliases("claude", newAliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveIntegration("claude", []string{newAliases["primary"]}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, _ := loadIntegration("claude")
|
||||
if loaded.Models[0] != "updated-model" {
|
||||
t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
|
||||
}
|
||||
if loaded.Aliases["primary"] != "updated-model" {
|
||||
t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
|
||||
}
|
||||
})
|
||||
}
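These subtests all lean on the convention of calling saveAliases and saveIntegration together. As a rough sketch (hypothetical helper, not part of this diff, assuming the cmd/config package above), the invariant could be captured in one place:

// savePrimary keeps models[0] and aliases["primary"] in sync by writing both together.
func savePrimary(integration, model string, aliases map[string]string) error {
	if aliases == nil {
		aliases = map[string]string{}
	}
	aliases["primary"] = model
	if err := saveAliases(integration, aliases); err != nil {
		return err
	}
	// models[0] must mirror aliases["primary"], so write it in the same step.
	return saveIntegration(integration, []string{model})
}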
@@ -46,6 +46,53 @@ func TestIntegrationConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save and load aliases", func(t *testing.T) {
|
||||
models := []string{"llama3.2"}
|
||||
if err := saveIntegration("claude", models); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aliases := map[string]string{
|
||||
"primary": "llama3.2:70b",
|
||||
"fast": "llama3.2:8b",
|
||||
}
|
||||
if err := saveAliases("claude", aliases); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases == nil {
|
||||
t.Fatal("expected aliases to be saved")
|
||||
}
|
||||
for k, v := range aliases {
|
||||
if config.Aliases[k] != v {
|
||||
t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("saveIntegration preserves aliases", func(t *testing.T) {
|
||||
if err := saveIntegration("claude", []string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := saveIntegration("claude", []string{"model-b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config, err := loadIntegration("claude")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if config.Aliases["primary"] != "model-a" {
|
||||
t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("defaultModel returns first model", func(t *testing.T) {
|
||||
saveIntegration("codex", []string{"model-a", "model-b"})
|
||||
|
||||
@@ -200,12 +247,10 @@ func TestLoadIntegration_CorruptedJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
// Create corrupted config.json file
|
||||
dir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
dir := filepath.Join(tmpDir, ".ollama")
|
||||
os.MkdirAll(dir, 0o755)
|
||||
os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
|
||||
|
||||
// Corrupted file is treated as empty, so loadIntegration returns not found
|
||||
_, err := loadIntegration("test")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent integration in corrupted file")
|
||||
@@ -267,7 +312,7 @@ func TestConfigPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config", "config.json")
|
||||
expected := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if path != expected {
|
||||
t.Errorf("expected %s, got %s", expected, path)
|
||||
}
|
||||
@@ -322,6 +367,183 @@ func TestLoad(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateConfig(t *testing.T) {
|
||||
t.Run("migrates legacy file to new location", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
newPath, _ := configPath()
|
||||
got, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatalf("new config not found: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("content mismatch: got %s", got)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("legacy file should have been removed")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
|
||||
t.Error("legacy directory should have been removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when no legacy file exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("expected no migration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skips corrupt legacy file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("should not migrate corrupt file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
|
||||
t.Error("corrupt legacy file should not have been deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("new path takes precedence over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
|
||||
|
||||
newDir := filepath.Join(tmpDir, ".ollama")
|
||||
os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := cfg.Integrations["new"]; !ok {
|
||||
t.Error("expected new-path integration to be loaded")
|
||||
}
|
||||
if _, ok := cfg.Integrations["old"]; ok {
|
||||
t.Error("legacy integration should not have been loaded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent when called twice", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := migrateConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated {
|
||||
t.Error("second migration should be a no-op")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy directory preserved if not empty", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
|
||||
|
||||
if _, err := migrateConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
|
||||
t.Error("directory with other files should not have been removed")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
|
||||
t.Error("other files in legacy directory should be untouched")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("save writes to new path after migration", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
// load triggers migration, then save should write to new path
|
||||
if err := saveIntegration("codex", []string{"qwen2.5"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
newPath := filepath.Join(tmpDir, ".ollama", "config.json")
|
||||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||||
t.Error("save should write to new path")
|
||||
}
|
||||
|
||||
// old path should not be recreated
|
||||
if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
|
||||
t.Error("save should not recreate legacy path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load triggers migration transparently", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
legacyDir := filepath.Join(tmpDir, ".ollama", "config")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
|
||||
|
||||
cfg, err := load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
|
||||
t.Error("migration via load() did not preserve data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Droid implements Runner and Editor for Droid integration
|
||||
@@ -37,7 +39,7 @@ type modelEntry struct {
|
||||
|
||||
func (d *Droid) String() string { return "Droid" }
|
||||
|
||||
func (d *Droid) Run(model string) error {
|
||||
func (d *Droid) Run(model string, args []string) error {
|
||||
if _, err := exec.LookPath("droid"); err != nil {
|
||||
return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
|
||||
}
|
||||
@@ -51,7 +53,7 @@ func (d *Droid) Run(model string) error {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("droid")
|
||||
cmd := exec.Command("droid", args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -117,7 +119,7 @@ func (d *Droid) Edit(models []string) error {
|
||||
newModels = append(newModels, modelEntry{
|
||||
Model: model,
|
||||
DisplayName: model,
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
BaseURL: envconfig.Host().String() + "/v1",
|
||||
APIKey: "ollama",
|
||||
Provider: "generic-chat-completion-api",
|
||||
MaxOutputTokens: 64000,
|
||||
|
||||
@@ -218,7 +218,7 @@ func TestDroidEdit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if model["baseUrl"] != "http://localhost:11434/v1" {
|
||||
if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
|
||||
}
|
||||
if model["apiKey"] != "ollama" {
|
||||
@@ -447,7 +447,7 @@ const testDroidSettingsFixture = `{
|
||||
{
|
||||
"model": "existing-ollama-model",
|
||||
"displayName": "existing-ollama-model",
|
||||
"baseUrl": "http://localhost:11434/v1",
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"apiKey": "ollama",
|
||||
"provider": "generic-chat-completion-api",
|
||||
"maxOutputTokens": 64000,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
// Runner can run an integration with a model.
|
||||
|
||||
type Runner interface {
|
||||
Run(model string) error
|
||||
Run(model string, args []string) error
|
||||
// String returns the human-readable name of the integration
|
||||
String() string
|
||||
}
|
||||
@@ -38,12 +39,39 @@ type Editor interface {
|
||||
Models() []string
|
||||
}
|
||||
|
||||
// AliasConfigurer can configure model aliases (e.g., for subagent routing).
// Integrations like Claude and Codex use this to route model requests to local models.
type AliasConfigurer interface {
// ConfigureAliases prompts the user to configure aliases and returns the updated map.
ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
// SetAliases syncs the configured aliases to the server
SetAliases(ctx context.Context, aliases map[string]string) error
}

// integrations is the registry of available integrations.
var integrations = map[string]Runner{
"claude": &Claude{},
"clawdbot": &Openclaw{},
"codex": &Codex{},
"moltbot": &Openclaw{},
"droid": &Droid{},
"opencode": &OpenCode{},
"openclaw": &Openclaw{},
}
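As a rough usage sketch (hypothetical, mirroring the runIntegration helper further down in this diff), the registry is consumed by a plain map lookup, and aliases such as "clawdbot" resolve to the same Runner value as "openclaw":

func launchByName(name, model string, extra []string) error {
	r, ok := integrations[name]
	if !ok {
		return fmt.Errorf("unknown integration: %s", name)
	}
	return r.Run(model, extra)
}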
// recommendedModels are shown when the user has no models or as suggestions.
// Order matters: local models first, then cloud models.
var recommendedModels = []selectItem{
{Name: "glm-4.7-flash", Description: "Recommended (requires ~25GB VRAM)"},
{Name: "qwen3:8b", Description: "Recommended (requires ~11GB VRAM)"},
{Name: "glm-4.7:cloud", Description: "Recommended"},
{Name: "kimi-k2.5:cloud", Description: "Recommended"},
}

// integrationAliases are hidden from the interactive selector but work as CLI arguments.
var integrationAliases = map[string]bool{
"clawdbot": true,
"moltbot": true,
}
func selectIntegration() (string, error) {
|
||||
@@ -54,6 +82,9 @@ func selectIntegration() (string, error) {
|
||||
names := slices.Sorted(maps.Keys(integrations))
|
||||
var items []selectItem
|
||||
for _, name := range names {
|
||||
if integrationAliases[name] {
|
||||
continue
|
||||
}
|
||||
r := integrations[name]
|
||||
description := r.String()
|
||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||
@@ -82,152 +113,190 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(models.Models) == 0 {
|
||||
return nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
var items []selectItem
|
||||
cloudModels := make(map[string]bool)
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
if m.RemoteModel != "" {
|
||||
cloudModels[m.Name] = true
|
||||
}
|
||||
items = append(items, selectItem{Name: m.Name})
|
||||
existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no local models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
// Get previously configured models (saved config takes precedence)
|
||||
var preChecked []string
|
||||
if saved, err := loadIntegration(name); err == nil {
|
||||
preChecked = saved.Models
|
||||
} else if editor, ok := r.(Editor); ok {
|
||||
preChecked = editor.Models()
|
||||
}
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
|
||||
|
||||
// If current model is configured, move to front of preChecked
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("no models available")
|
||||
}
|
||||
|
||||
// Sort: checked first, then alphabetical
|
||||
slices.SortFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
|
||||
var selected []string
|
||||
// only editors support multi-model selection
|
||||
if _, ok := r.(Editor); ok {
|
||||
selected, err = multiSelectPrompt(fmt.Sprintf("Select models for %s:", r), items, preChecked)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items)
|
||||
prompt := fmt.Sprintf("Select model for %s:", r)
|
||||
if _, ok := r.(AliasConfigurer); ok {
|
||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||
}
|
||||
model, err := selectPrompt(prompt, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selected = []string{model}
|
||||
}
|
||||
|
||||
// if any model in selected is a cloud model, ensure signed in
|
||||
var toPull []string
|
||||
for _, m := range selected {
|
||||
if !existingModels[m] {
|
||||
toPull = append(toPull, m)
|
||||
}
|
||||
}
|
||||
if len(toPull) > 0 {
|
||||
msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errCancelled
|
||||
}
|
||||
for _, m := range toPull {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, m); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull %s: %w", m, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
|
||||
if existingModels[model] {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf("Download %s?", model)
|
||||
if ok, err := confirmPrompt(msg); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
return errCancelled
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
if err := pullModel(ctx, client, model); err != nil {
|
||||
return fmt.Errorf("failed to pull %s: %w", model, err)
|
||||
}
|
||||
return nil
|
||||
}
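For context, a minimal caller of pullIfNeeded might look like the sketch below (hypothetical wrapper, assuming the cmd/config package and the api.Client used throughout this diff):

// ensureModelPresent lists installed models, then downloads the requested
// model only when it is missing, prompting the user first via pullIfNeeded.
func ensureModelPresent(ctx context.Context, client *api.Client, model string) error {
	resp, err := client.List(ctx)
	if err != nil {
		return err
	}
	existing := make(map[string]bool, len(resp.Models))
	for _, m := range resp.Models {
		existing[m.Name] = true
	}
	return pullIfNeeded(ctx, client, existing, model)
}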
func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
models, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
var existing []modelInfo
|
||||
for _, m := range models.Models {
|
||||
existing = append(existing, modelInfo{
|
||||
Name: m.Name,
|
||||
Remote: m.RemoteModel != "",
|
||||
})
|
||||
}
|
||||
|
||||
items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if len(items) == 0 {
|
||||
return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||
}
|
||||
|
||||
return items, existingModels, cloudModels, client, nil
|
||||
}
|
||||
|
||||
func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
|
||||
var selectedCloudModels []string
|
||||
for _, m := range selected {
|
||||
if cloudModels[m] {
|
||||
selectedCloudModels = append(selectedCloudModels, m)
|
||||
}
|
||||
}
|
||||
if len(selectedCloudModels) > 0 {
|
||||
// ensure user is signed in
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return selected, nil
|
||||
}
|
||||
if len(selectedCloudModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return nil, err
|
||||
}
|
||||
user, err := client.Whoami(ctx)
|
||||
if err == nil && user != nil && user.Name != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return nil, fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
var aErr api.AuthorizationError
|
||||
if !errors.As(err, &aErr) || aErr.SigninURL == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
modelList := strings.Join(selectedCloudModels, ", ")
|
||||
yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
|
||||
if err != nil || !yes {
|
||||
return fmt.Errorf("%s requires sign in", modelList)
|
||||
}
|
||||
|
||||
// TODO(parthsareen): extract into auth package for cmd
|
||||
// Auto-open browser (best effort, fail silently)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
|
||||
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", aErr.SigninURL).Start()
|
||||
case "linux":
|
||||
_ = exec.Command("xdg-open", aErr.SigninURL).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
spinnerFrames := []string{"|", "/", "-", "\\"}
|
||||
frame := 0
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return selected, nil
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K")
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
frame++
|
||||
fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
|
||||
|
||||
// poll every 10th frame (~2 seconds)
|
||||
if frame%10 == 0 {
|
||||
u, err := client.Whoami(ctx)
|
||||
if err == nil && u != nil && u.Name != "" {
|
||||
fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func runIntegration(name, modelName string) error {
|
||||
func runIntegration(name, modelName string, args []string) error {
|
||||
r, ok := integrations[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
|
||||
return r.Run(modelName)
|
||||
return r.Run(modelName, args)
|
||||
}
|
||||
|
||||
// LaunchCmd returns the cobra command for launching integrations.
|
||||
@@ -236,7 +305,7 @@ func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) erro
|
||||
var configFlag bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "launch [INTEGRATION]",
|
||||
Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
|
||||
Short: "Launch an integration with Ollama",
|
||||
Long: `Launch an integration configured with Ollama models.
|
||||
|
||||
@@ -245,19 +314,43 @@ Supported integrations:
|
||||
codex Codex
|
||||
droid Droid
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
|
||||
Examples:
|
||||
ollama launch
|
||||
ollama launch claude
|
||||
ollama launch claude --model <model>
|
||||
ollama launch droid --config (does not auto-launch)`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
ollama launch droid --config (does not auto-launch)
|
||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||
ollama launch codex -- --sandbox workspace-write`,
|
||||
Args: cobra.ArbitraryArgs,
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Extract integration name and args to pass through using -- separator
|
||||
var name string
|
||||
if len(args) > 0 {
|
||||
name = args[0]
|
||||
var passArgs []string
|
||||
dashIdx := cmd.ArgsLenAtDash()
|
||||
|
||||
if dashIdx == -1 {
|
||||
// No "--" separator: only allow 0 or 1 args (integration name)
|
||||
if len(args) > 1 {
|
||||
return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
|
||||
}
|
||||
if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
// "--" was used: args before it = integration name, args after = passthrough
|
||||
if dashIdx > 1 {
|
||||
return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
}
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
passArgs = args[dashIdx:]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
var err error
|
||||
name, err = selectIntegration()
|
||||
if errors.Is(err, errCancelled) {
|
||||
@@ -273,16 +366,89 @@ Examples:
|
||||
return fmt.Errorf("unknown integration: %s", name)
|
||||
}
|
||||
|
||||
// If launching without --model, use saved config if available
|
||||
if !configFlag && modelFlag == "" {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
return runIntegration(name, config.Models[0])
|
||||
if modelFlag != "" {
|
||||
if err := saveIntegration(name, []string{modelFlag}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !configFlag {
|
||||
if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 {
|
||||
model := config.Models[0]
|
||||
if config.Aliases == nil {
|
||||
config.Aliases = make(map[string]string)
|
||||
}
|
||||
|
||||
client, _ := api.ClientFromEnvironment()
|
||||
needsLocalSave := false
|
||||
|
||||
if config.Aliases["primary"] != model {
|
||||
config.Aliases["primary"] = model
|
||||
needsLocalSave = true
|
||||
}
|
||||
|
||||
if isCloudModel(cmd.Context(), client, model) {
|
||||
if config.Aliases["fast"] == "" || !isCloudModel(cmd.Context(), client, config.Aliases["fast"]) {
|
||||
config.Aliases["fast"] = model
|
||||
needsLocalSave = true
|
||||
}
|
||||
} else if config.Aliases["fast"] != "" {
|
||||
delete(config.Aliases, "fast")
|
||||
needsLocalSave = true
|
||||
}
|
||||
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
if err := ac.SetAliases(cmd.Context(), config.Aliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not update server aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
}
|
||||
}
|
||||
|
||||
if needsLocalSave {
|
||||
_ = saveAliases(name, config.Aliases)
|
||||
}
|
||||
|
||||
return runIntegration(name, model, passArgs)
|
||||
}
|
||||
}
|
||||
|
||||
if ac, ok := r.(AliasConfigurer); ok {
|
||||
var existingAliases map[string]string
|
||||
if existing, err := loadIntegration(name); err == nil {
|
||||
existingAliases = existing.Aliases
|
||||
}
|
||||
aliases, updated, err := ac.ConfigureAliases(cmd.Context(), "", existingAliases, configFlag)
|
||||
if errors.Is(err, errCancelled) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if updated {
|
||||
// Atomic: sync to server FIRST, only save locally if server succeeds
|
||||
if err := ac.SetAliases(cmd.Context(), aliases); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%sWarning: Could not update server aliases: %v%s\n", ansiGray, err, ansiReset)
|
||||
fmt.Fprintf(os.Stderr, "%sConfiguration not saved.%s\n\n", ansiGray, ansiReset)
|
||||
// Still launch with the model, but don't save config
|
||||
} else {
|
||||
if err := saveAliases(name, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := saveIntegration(name, []string{aliases["primary"]}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
|
||||
return runIntegration(name, aliases["primary"], passArgs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return runIntegration(name, aliases["primary"], passArgs)
|
||||
}
|
||||
|
||||
var models []string
|
||||
if modelFlag != "" {
|
||||
// When --model is specified, merge with existing models (new model becomes default)
|
||||
models = []string{modelFlag}
|
||||
if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
|
||||
for _, m := range existing.Models {
|
||||
@@ -337,13 +503,13 @@ Examples:
|
||||
|
||||
if configFlag {
|
||||
if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
return runIntegration(name, models[0])
|
||||
return runIntegration(name, models[0], passArgs)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -351,3 +517,167 @@ Examples:
|
||||
cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
|
||||
return cmd
|
||||
}
type modelInfo struct {
Name string
Remote bool
ToolCapable bool
}
// buildModelList merges existing models with recommendations, sorts them, and returns
|
||||
// the ordered items along with maps of existing and cloud model names.
|
||||
func buildModelList(existing []modelInfo, preChecked []string, current string) (items []selectItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
|
||||
existingModels = make(map[string]bool)
|
||||
cloudModels = make(map[string]bool)
|
||||
recommended := make(map[string]bool)
|
||||
var hasLocalModel, hasCloudModel bool
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
recommended[rec.Name] = true
|
||||
}
|
||||
|
||||
for _, m := range existing {
|
||||
existingModels[m.Name] = true
|
||||
if m.Remote {
|
||||
cloudModels[m.Name] = true
|
||||
hasCloudModel = true
|
||||
} else {
|
||||
hasLocalModel = true
|
||||
}
|
||||
displayName := strings.TrimSuffix(m.Name, ":latest")
|
||||
existingModels[displayName] = true
|
||||
item := selectItem{Name: displayName}
|
||||
if recommended[displayName] {
|
||||
item.Description = "recommended"
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
for _, rec := range recommendedModels {
|
||||
if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
|
||||
continue
|
||||
}
|
||||
items = append(items, rec)
|
||||
if strings.HasSuffix(rec.Name, ":cloud") {
|
||||
cloudModels[rec.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
checked := make(map[string]bool, len(preChecked))
|
||||
for _, n := range preChecked {
|
||||
checked[n] = true
|
||||
}
|
||||
|
||||
// Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
|
||||
for _, item := range items {
|
||||
if item.Name == current || strings.HasPrefix(item.Name, current+":") {
|
||||
current = item.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if checked[current] {
|
||||
preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
|
||||
}
|
||||
|
||||
// Non-existing models get "install?" suffix and are pushed to the bottom.
|
||||
// When user has no models, preserve recommended order.
|
||||
notInstalled := make(map[string]bool)
|
||||
for i := range items {
|
||||
if !existingModels[items[i].Name] {
|
||||
notInstalled[items[i].Name] = true
|
||||
if items[i].Description != "" {
|
||||
items[i].Description += ", install?"
|
||||
} else {
|
||||
items[i].Description = "install?"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasLocalModel || hasCloudModel {
|
||||
slices.SortStableFunc(items, func(a, b selectItem) int {
|
||||
ac, bc := checked[a.Name], checked[b.Name]
|
||||
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
|
||||
|
||||
if ac != bc {
|
||||
if ac {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if !ac && !bc && aNew != bNew {
|
||||
if aNew {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
|
||||
})
|
||||
}
|
||||
|
||||
return items, preChecked, existingModels, cloudModels
|
||||
}
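A short sketch of how buildModelList is driven (hypothetical input values; in selectModels and listModels above the modelInfo slice comes from client.List):

func exampleBuildModelList() {
	installed := []modelInfo{
		{Name: "llama3.2:latest", Remote: false},
		{Name: "glm-4.7:cloud", Remote: true},
	}
	items, checked, existing, cloud := buildModelList(installed, []string{"llama3.2"}, "llama3.2")
	// items: pre-checked models first, other installed models next, and
	// recommendations that are not installed last with an "install?" hint.
	// existing and cloud are lookup maps used for pulling and sign-in checks.
	fmt.Println(len(items), checked, existing, cloud)
}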
// isCloudModel checks if a model is a cloud model using suffix check and server API fallback.
func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
// Fast path: check suffix first (no network call)
if strings.HasSuffix(name, ":cloud") {
return true
}
// Fallback: check server if client available
if client != nil {
resp, err := client.Show(ctx, &api.ShowRequest{Name: name})
if err == nil && resp.RemoteModel != "" {
return true
}
}
return false
}
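Both detection paths in one hedged sketch; the nil-client call is the suffix-only path that the tests later in this diff rely on:

func exampleIsCloudModel(ctx context.Context, client *api.Client) {
	// Suffix path: ":cloud" names are detected without any network call.
	_ = isCloudModel(ctx, nil, "glm-4.7:cloud") // true
	// Server path: other names fall back to Show, which reports RemoteModel
	// for cloud-backed models (the model name below is illustrative).
	_ = isCloudModel(ctx, client, "my-cloud-backed-model")
}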
func pullModel(ctx context.Context, client *api.Client, model string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: model}
|
||||
return client.Pull(ctx, &request, fn)
|
||||
}
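pullModel is a thin wrapper over client.Pull that renders progress bars; a minimal usage sketch (the model name here is illustrative):

func examplePull(ctx context.Context) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}
	// Streams layer progress to stderr and returns when the pull finishes.
	return pullModel(ctx, client, "qwen3:8b")
}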
@@ -1,10 +1,13 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -90,8 +93,8 @@ func TestLaunchCmd(t *testing.T) {
|
||||
cmd := LaunchCmd(mockCheck)
|
||||
|
||||
t.Run("command structure", func(t *testing.T) {
|
||||
if cmd.Use != "launch [INTEGRATION]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION]")
|
||||
if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
|
||||
t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
|
||||
}
|
||||
if cmd.Short == "" {
|
||||
t.Error("Short description should not be empty")
|
||||
@@ -121,7 +124,7 @@ func TestLaunchCmd(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunIntegration_UnknownIntegration(t *testing.T) {
|
||||
err := runIntegration("unknown-integration", "model")
|
||||
err := runIntegration("unknown-integration", "model", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown integration, got nil")
|
||||
}
|
||||
@@ -174,15 +177,353 @@ func TestLaunchCmd_NilHeartbeat(t *testing.T) {
|
||||
func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
|
||||
for name, r := range integrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test String() doesn't panic and returns non-empty
|
||||
displayName := r.String()
|
||||
if displayName == "" {
|
||||
t.Error("String() should not return empty")
|
||||
}
|
||||
|
||||
// Test Run() exists (we can't call it without actually running the command)
|
||||
// Just verify the method is available
|
||||
var _ func(string) error = r.Run
|
||||
var _ func(string, []string) error = r.Run
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseArgs(t *testing.T) {
|
||||
// Tests reflect cobra's ArgsLenAtDash() semantics:
|
||||
// - cobra strips "--" from args
|
||||
// - ArgsLenAtDash() returns the index where "--" was, or -1
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string // args as cobra delivers them (no "--")
|
||||
dashIdx int // what ArgsLenAtDash() returns
|
||||
wantName string
|
||||
wantArgs []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no extra args, no dash",
|
||||
args: []string{"claude"},
|
||||
dashIdx: -1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "with extra args after --",
|
||||
args: []string{"codex", "-p", "myprofile"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"-p", "myprofile"},
|
||||
},
|
||||
{
|
||||
name: "extra args only after --",
|
||||
args: []string{"codex", "--sandbox", "workspace-write"},
|
||||
dashIdx: 1,
|
||||
wantName: "codex",
|
||||
wantArgs: []string{"--sandbox", "workspace-write"},
|
||||
},
|
||||
{
|
||||
name: "-- at end with no args after",
|
||||
args: []string{"claude"},
|
||||
dashIdx: 1,
|
||||
wantName: "claude",
|
||||
},
|
||||
{
|
||||
name: "-- with no integration name",
|
||||
args: []string{"--verbose"},
|
||||
dashIdx: 0,
|
||||
wantName: "",
|
||||
wantArgs: []string{"--verbose"},
|
||||
},
|
||||
{
|
||||
name: "multiple args before -- is error",
|
||||
args: []string{"claude", "codex", "--verbose"},
|
||||
dashIdx: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple args without -- is error",
|
||||
args: []string{"claude", "codex"},
|
||||
dashIdx: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no args, no dash",
|
||||
args: []string{},
|
||||
dashIdx: -1,
|
||||
wantName: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Simulate the parsing logic from LaunchCmd using dashIdx
|
||||
var name string
|
||||
var parsedArgs []string
|
||||
var err error
|
||||
|
||||
dashIdx := tt.dashIdx
|
||||
args := tt.args
|
||||
|
||||
if dashIdx == -1 {
|
||||
if len(args) > 1 {
|
||||
err = fmt.Errorf("unexpected arguments: %v", args[1:])
|
||||
} else if len(args) == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
} else {
|
||||
if dashIdx > 1 {
|
||||
err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
|
||||
} else {
|
||||
if dashIdx == 1 {
|
||||
name = args[0]
|
||||
}
|
||||
parsedArgs = args[dashIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if name != tt.wantName {
|
||||
t.Errorf("name = %q, want %q", name, tt.wantName)
|
||||
}
|
||||
if !slices.Equal(parsedArgs, tt.wantArgs) {
|
||||
t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
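The table above mirrors cobra's ArgsLenAtDash semantics; the same rule, extracted into a hypothetical helper (not part of this diff), reads:

// splitLaunchArgs allows at most one integration name before "--" and passes
// everything after "--" through to the integration verbatim.
func splitLaunchArgs(args []string, dashIdx int) (name string, pass []string, err error) {
	if dashIdx == -1 {
		if len(args) > 1 {
			return "", nil, fmt.Errorf("unexpected arguments: %v", args[1:])
		}
		if len(args) == 1 {
			name = args[0]
		}
		return name, nil, nil
	}
	if dashIdx > 1 {
		return "", nil, fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
	}
	if dashIdx == 1 {
		name = args[0]
	}
	return name, args[dashIdx:], nil
}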
func TestIsCloudModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
{"glm-4.7:cloud", true},
|
||||
{"kimi-k2.5:cloud", true},
|
||||
{"glm-4.7-flash", false},
|
||||
{"glm-4.7-flash:latest", false},
|
||||
{"cloud-model", false},
|
||||
{"model:cloudish", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// nil client = suffix-only path (no API calls)
|
||||
if got := isCloudModel(context.Background(), nil, tt.name); got != tt.want {
|
||||
t.Errorf("isCloudModel(%q) = %v, want %v", tt.name, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func names(items []selectItem) []string {
|
||||
var out []string
|
||||
for _, item := range items {
|
||||
out = append(out, item.Name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"glm-4.7-flash", "qwen3:8b", "glm-4.7:cloud", "kimi-k2.5:cloud"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("item %q should have description ending with 'install?', got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "qwen2.5:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"llama3.2", "qwen2.5", "glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("cloud recs should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
want := []string{"glm-4.7:cloud", "llama3.2", "glm-4.7-flash", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("mixed models should be alphabetical (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_PreCheckedFirst(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
|
||||
got := names(items)
|
||||
|
||||
if got[0] != "llama3.2" {
|
||||
t.Errorf("pre-checked model should be first, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
|
||||
for _, item := range items {
|
||||
switch item.Name {
|
||||
case "glm-4.7-flash", "glm-4.7:cloud":
|
||||
if strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("installed recommended %q should not have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "kimi-k2.5:cloud", "qwen3:8b":
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed recommended %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// glm-4.7-flash and glm-4.7:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
|
||||
want := []string{"glm-4.7-flash", "glm-4.7:cloud", "kimi-k2.5:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("existing cloud models should sort normally (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "kimi-k2.5:cloud", Remote: true},
|
||||
}
|
||||
|
||||
items, _, _, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
want := []string{"kimi-k2.5:cloud", "llama3.2", "glm-4.7-flash", "glm-4.7:cloud", "qwen3:8b"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("only non-installed models should be at bottom (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
|
||||
if !strings.HasSuffix(item.Description, "install?") {
|
||||
t.Errorf("non-installed %q should have 'install?' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_LatestTagStripped(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "glm-4.7-flash:latest", Remote: false},
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
}
|
||||
|
||||
items, _, existingModels, _ := buildModelList(existing, nil, "")
|
||||
got := names(items)
|
||||
|
||||
// :latest should be stripped from display names
|
||||
for _, name := range got {
|
||||
if strings.HasSuffix(name, ":latest") {
|
||||
t.Errorf("name %q should not have :latest suffix", name)
|
||||
}
|
||||
}
|
||||
|
||||
// glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
|
||||
count := 0
|
||||
for _, name := range got {
|
||||
if name == "glm-4.7-flash" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
|
||||
}
|
||||
|
||||
// Stripped name should be in existingModels so it won't be pulled
|
||||
if !existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should be in existingModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
|
||||
existing := []modelInfo{
|
||||
{Name: "llama3.2:latest", Remote: false},
|
||||
{Name: "glm-4.7:cloud", Remote: true},
|
||||
}
|
||||
|
||||
_, _, existingModels, cloudModels := buildModelList(existing, nil, "")
|
||||
|
||||
if !existingModels["llama3.2"] {
|
||||
t.Error("llama3.2 should be in existingModels")
|
||||
}
|
||||
if !existingModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in existingModels")
|
||||
}
|
||||
if existingModels["glm-4.7-flash"] {
|
||||
t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
|
||||
}
|
||||
|
||||
if !cloudModels["glm-4.7:cloud"] {
|
||||
t.Error("glm-4.7:cloud should be in cloudModels")
|
||||
}
|
||||
if !cloudModels["kimi-k2.5:cloud"] {
|
||||
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
|
||||
}
|
||||
if cloudModels["llama3.2"] {
|
||||
t.Error("llama3.2 should not be in cloudModels")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAliasConfigurerInterface(t *testing.T) {
|
||||
t.Run("claude implements AliasConfigurer", func(t *testing.T) {
|
||||
claude := &Claude{}
|
||||
if _, ok := interface{}(claude).(AliasConfigurer); !ok {
|
||||
t.Error("Claude should implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
|
||||
codex := &Codex{}
|
||||
if _, ok := interface{}(codex).(AliasConfigurer); ok {
|
||||
t.Error("Codex should not implement AliasConfigurer")
|
||||
}
|
||||
})
|
||||
}
252
cmd/config/openclaw.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type Openclaw struct{}
|
||||
|
||||
func (c *Openclaw) String() string { return "OpenClaw" }
|
||||
|
||||
func (c *Openclaw) Run(model string, args []string) error {
|
||||
bin := "openclaw"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
bin = "clawdbot"
|
||||
if _, err := exec.LookPath(bin); err != nil {
|
||||
return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
|
||||
}
|
||||
}
|
||||
|
||||
models := []string{model}
|
||||
if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
} else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
|
||||
models = config.Models
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
return fmt.Errorf("setup failed: %w", err)
|
||||
}
|
||||
|
||||
if !c.onboarded() {
|
||||
// Onboarding not completed: run it (model already set via Edit)
|
||||
// Use "ollama" as gateway token for simple local access
|
||||
cmd := exec.Command(bin, "onboard",
|
||||
"--auth-choice", "skip",
|
||||
"--gateway-token", "ollama",
|
||||
)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// Onboarding completed: run gateway
|
||||
cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// Capture output to detect "already running" message
|
||||
var outputBuf bytes.Buffer
|
||||
cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
|
||||
cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
|
||||
fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// onboarded checks if OpenClaw onboarding wizard was completed
|
||||
// by looking for the wizard.lastRunAt marker in the config
|
||||
func (c *Openclaw) onboarded() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for wizard.lastRunAt marker (set when onboarding completes)
|
||||
wizard, _ := config["wizard"].(map[string]any)
|
||||
if wizard == nil {
|
||||
return false
|
||||
}
|
||||
lastRunAt, _ := wizard["lastRunAt"].(string)
|
||||
return lastRunAt != ""
|
||||
}
|
||||
|
||||
func (c *Openclaw) Paths() []string {
|
||||
home, _ := os.UserHomeDir()
|
||||
p := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return []string{p}
|
||||
}
|
||||
legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if _, err := os.Stat(legacy); err == nil {
|
||||
return []string{legacy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Openclaw) Edit(models []string) error {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
configPath := filepath.Join(home, ".openclaw", "openclaw.json")
|
||||
legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read into map[string]any to preserve unknown fields
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
} else if data, err := os.ReadFile(legacyPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config)
|
||||
}
|
||||
|
||||
// Navigate/create: models.providers.ollama (preserving other providers)
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
if modelsSection == nil {
|
||||
modelsSection = make(map[string]any)
|
||||
}
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
if providers == nil {
|
||||
providers = make(map[string]any)
|
||||
}
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
if ollama == nil {
|
||||
ollama = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama["baseUrl"] = envconfig.Host().String() + "/v1"
|
||||
// needed to register provider
|
||||
ollama["apiKey"] = "ollama-local"
|
||||
// TODO(parthsareen): potentially move to responses
|
||||
ollama["api"] = "openai-completions"
|
||||
|
||||
// Build map of existing models to preserve user customizations
|
||||
existingModels, _ := ollama["models"].([]any)
|
||||
existingByID := make(map[string]map[string]any)
|
||||
for _, m := range existingModels {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
existingByID[id] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newModels []any
|
||||
for _, model := range models {
|
||||
entry := map[string]any{
|
||||
"id": model,
|
||||
"name": model,
|
||||
"reasoning": false,
|
||||
"input": []any{"text"},
|
||||
"cost": map[string]any{
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cacheRead": 0,
|
||||
"cacheWrite": 0,
|
||||
},
|
||||
// TODO(parthsareen): get these values from API
|
||||
"contextWindow": 131072,
|
||||
"maxTokens": 16384,
|
||||
}
|
||||
// Merge existing fields (user customizations)
|
||||
if existing, ok := existingByID[model]; ok {
|
||||
for k, v := range existing {
|
||||
if _, isNew := entry[k]; !isNew {
|
||||
entry[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
newModels = append(newModels, entry)
|
||||
}
|
||||
ollama["models"] = newModels
|
||||
|
||||
providers["ollama"] = ollama
|
||||
modelsSection["providers"] = providers
|
||||
config["models"] = modelsSection
|
||||
|
||||
// Update agents.defaults.model.primary (preserving other agent settings)
|
||||
agents, _ := config["agents"].(map[string]any)
|
||||
if agents == nil {
|
||||
agents = make(map[string]any)
|
||||
}
|
||||
defaults, _ := agents["defaults"].(map[string]any)
|
||||
if defaults == nil {
|
||||
defaults = make(map[string]any)
|
||||
}
|
||||
modelConfig, _ := defaults["model"].(map[string]any)
|
||||
if modelConfig == nil {
|
||||
modelConfig = make(map[string]any)
|
||||
}
|
||||
modelConfig["primary"] = "ollama/" + models[0]
|
||||
defaults["model"] = modelConfig
|
||||
agents["defaults"] = defaults
|
||||
config["agents"] = agents
|
||||
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeWithBackup(configPath, data)
|
||||
}
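For orientation, a rough sketch of what ~/.openclaw/openclaw.json looks like after Edit([]string{"llama3.2"}) on a fresh install with the default host. This is illustrative only (the constant name is made up) and mirrors the fixture style used in the tests below.

// Illustrative only: approximate result of Edit([]string{"llama3.2"}) with the
// default OLLAMA_HOST; unrelated fields are omitted.
const exampleOpenclawConfig = `{
  "models": {
    "providers": {
      "ollama": {
        "baseUrl": "http://127.0.0.1:11434/v1",
        "apiKey": "ollama-local",
        "api": "openai-completions",
        "models": [{
          "id": "llama3.2",
          "name": "llama3.2",
          "reasoning": false,
          "input": ["text"],
          "cost": {"input": 0, "output": 0, "cacheRead": 0, "cacheWrite": 0},
          "contextWindow": 131072,
          "maxTokens": 16384
        }]
      }
    }
  },
  "agents": {
    "defaults": {"model": {"primary": "ollama/llama3.2"}}
  }
}`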
|
||||
|
||||
func (c *Openclaw) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
|
||||
if err != nil {
|
||||
config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
modelsSection, _ := config["models"].(map[string]any)
|
||||
providers, _ := modelsSection["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
|
||||
var result []string
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if id, ok := entry["id"].(string); ok {
|
||||
result = append(result, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
878
cmd/config/openclaw_test.go
Normal file
@@ -0,0 +1,878 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenclawIntegration(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := c.String(); got != "OpenClaw" {
|
||||
t.Errorf("String() = %q, want %q", got, "OpenClaw")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = c
|
||||
})
|
||||
|
||||
t.Run("implements Editor", func(t *testing.T) {
|
||||
var _ Editor = c
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("fresh install", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("multiple models - first is primary", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelExists(t, configPath, "mistral")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
|
||||
})
|
||||
|
||||
t.Run("preserve other providers", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve top-level keys", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme was removed")
|
||||
}
|
||||
if cfg["mcp"] == nil {
|
||||
t.Error("mcp was removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on models", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
// User adds custom field
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
entry["customField"] = "user-value"
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
models = cfg["models"].(map[string]any)
|
||||
providers = models["providers"].(map[string]any)
|
||||
ollama = providers["ollama"].(map[string]any)
|
||||
modelList = ollama["models"].([]any)
|
||||
entry = modelList[0].(map[string]any)
|
||||
if entry["customField"] != "user-value" {
|
||||
t.Error("custom field was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("edit replaces models list", func(t *testing.T) {
|
||||
cleanup()
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
assertOpenclawModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("empty models is no-op", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
original := `{"existing":"data"}`
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
c.Edit([]string{})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
if string(data) != original {
|
||||
t.Error("empty models should not modify file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupted JSON treated as empty", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Error("result should be valid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type models section", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModels(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
t.Run("no config returns nil", func(t *testing.T) {
|
||||
if models := c.Models(); len(models) > 0 {
|
||||
t.Errorf("expected nil/empty, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns all ollama models", func(t *testing.T) {
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[
|
||||
{"id":"llama3.2"},
|
||||
{"id":"mistral"}
|
||||
]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 2 {
|
||||
t.Errorf("expected 2 models, got %v", models)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func assertOpenclawModelExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Errorf("model %s not found", model)
|
||||
}
|
||||
|
||||
func assertOpenclawModelNotExists(t *testing.T, path, model string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models, _ := cfg["models"].(map[string]any)
|
||||
providers, _ := models["providers"].(map[string]any)
|
||||
ollama, _ := providers["ollama"].(map[string]any)
|
||||
modelList, _ := ollama["models"].([]any)
|
||||
for _, m := range modelList {
|
||||
if entry, ok := m.(map[string]any); ok {
|
||||
if entry["id"] == model {
|
||||
t.Errorf("model %s should not exist", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
|
||||
t.Helper()
|
||||
data, _ := os.ReadFile(path)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
model := defaults["model"].(map[string]any)
|
||||
if model["primary"] != expected {
|
||||
t.Errorf("primary model = %v, want %v", model["primary"], expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns path when config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Errorf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when config missing", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if paths := c.Paths(); paths != nil {
|
||||
t.Errorf("expected nil, got %v", paths)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawModelsEdgeCases(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("corrupted JSON returns nil", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at models level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at providers level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong type at ollama level", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
|
||||
if models := c.Models(); models != nil {
|
||||
t.Errorf("expected nil, got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model entry missing id", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for missing id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model id is not string", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
|
||||
if len(c.Models()) != 0 {
|
||||
t.Error("expected empty for non-string id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEditSchemaFields(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
ollama := providers["ollama"].(map[string]any)
|
||||
modelList := ollama["models"].([]any)
|
||||
entry := modelList[0].(map[string]any)
|
||||
|
||||
// Verify required schema fields
|
||||
if entry["reasoning"] != false {
|
||||
t.Error("reasoning should be false")
|
||||
}
|
||||
if entry["input"] == nil {
|
||||
t.Error("input should be set")
|
||||
}
|
||||
if entry["contextWindow"] == nil {
|
||||
t.Error("contextWindow should be set")
|
||||
}
|
||||
if entry["maxTokens"] == nil {
|
||||
t.Error("maxTokens should be set")
|
||||
}
|
||||
cost := entry["cost"].(map[string]any)
|
||||
if cost["cacheRead"] == nil {
|
||||
t.Error("cost.cacheRead should be set")
|
||||
}
|
||||
if cost["cacheWrite"] == nil {
|
||||
t.Error("cost.cacheWrite should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEditModelNames(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
|
||||
|
||||
t.Run("model with colon tag", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "llama3.2:70b")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
|
||||
})
|
||||
|
||||
t.Run("model with slash", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"library/model:tag"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "library/model:tag")
|
||||
assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
|
||||
})
|
||||
|
||||
t.Run("model with hyphen", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := c.Edit([]string{"test-model"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assertOpenclawModelExists(t, configPath, "test-model")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEditAgentsPreservation(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
cleanup := func() { os.RemoveAll(configDir) }
|
||||
|
||||
t.Run("preserve other agent defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature setting was lost")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserve other agents besides defaults", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2"})
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent was lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const testOpenclawFixture = `{
|
||||
"theme": "dark",
|
||||
"mcp": {"servers": {"custom": {"enabled": true}}},
|
||||
"models": {
|
||||
"providers": {
|
||||
"anthropic": {"apiKey": "xxx"},
|
||||
"ollama": {
|
||||
"baseUrl": "http://127.0.0.1:11434/v1",
|
||||
"models": [{"id": "old-model", "customField": "preserved"}]
|
||||
}
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {"model": {"primary": "old"}, "temperature": 0.7},
|
||||
"custom-agent": {"foo": "bar"}
|
||||
}
|
||||
}`
|
||||
|
||||
func TestOpenclawEdit_RoundTrip(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
|
||||
// Verify top-level preserved
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme not preserved")
|
||||
}
|
||||
mcp := cfg["mcp"].(map[string]any)
|
||||
servers := mcp["servers"].(map[string]any)
|
||||
if servers["custom"] == nil {
|
||||
t.Error("mcp.servers.custom not preserved")
|
||||
}
|
||||
|
||||
// Verify other providers preserved
|
||||
models := cfg["models"].(map[string]any)
|
||||
providers := models["providers"].(map[string]any)
|
||||
if providers["anthropic"] == nil {
|
||||
t.Error("anthropic provider not preserved")
|
||||
}
|
||||
|
||||
// Verify agents preserved
|
||||
agents := cfg["agents"].(map[string]any)
|
||||
if agents["custom-agent"] == nil {
|
||||
t.Error("custom-agent not preserved")
|
||||
}
|
||||
defaults := agents["defaults"].(map[string]any)
|
||||
if defaults["temperature"] != 0.7 {
|
||||
t.Error("temperature not preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_Idempotent(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
firstData, _ := os.ReadFile(configPath)
|
||||
|
||||
c.Edit([]string{"llama3.2", "mistral"})
|
||||
secondData, _ := os.ReadFile(configPath)
|
||||
|
||||
if string(firstData) != string(secondData) {
|
||||
t.Error("repeated edits with same models produced different results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
|
||||
|
||||
for i := range 10 {
|
||||
models := []string{"model-a", "model-b"}
|
||||
if i%2 == 0 {
|
||||
models = []string{"model-x", "model-y", "model-z"}
|
||||
}
|
||||
if err := c.Edit(models); err != nil {
|
||||
t.Fatalf("edit %d failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
t.Fatalf("file is not valid JSON after multiple edits: %v", err)
|
||||
}
|
||||
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("theme lost after multiple edits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_BackupCreated(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
configPath := filepath.Join(configDir, "openclaw.json")
|
||||
backupDir := filepath.Join(os.TempDir(), "ollama-backups")
|
||||
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
|
||||
original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
|
||||
os.WriteFile(configPath, []byte(original), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
|
||||
foundBackup := false
|
||||
for _, backup := range backups {
|
||||
data, _ := os.ReadFile(backup)
|
||||
if string(data) == original {
|
||||
foundBackup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBackup {
|
||||
t.Error("backup with original content not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawClawdbotAlias(t *testing.T) {
|
||||
for _, alias := range []string{"clawdbot", "moltbot"} {
|
||||
t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
|
||||
r, ok := integrations[alias]
|
||||
if !ok {
|
||||
t.Fatalf("%s not found in integrations", alias)
|
||||
}
|
||||
if _, ok := r.(*Openclaw); !ok {
|
||||
t.Errorf("%s integration is %T, want *Openclaw", alias, r)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(alias+" is hidden from selector", func(t *testing.T) {
|
||||
if !integrationAliases[alias] {
|
||||
t.Errorf("%s should be in integrationAliases", alias)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawLegacyPaths(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
|
||||
t.Errorf("expected legacy path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
|
||||
|
||||
paths := c.Paths()
|
||||
if len(paths) != 1 {
|
||||
t.Fatalf("expected 1 path, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != filepath.Join(newDir, "openclaw.json") {
|
||||
t.Errorf("expected new path, got %s", paths[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models reads from legacy path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "llama3.2" {
|
||||
t.Errorf("expected [llama3.2], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Models prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
|
||||
}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
|
||||
"models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
|
||||
}`), 0o644)
|
||||
|
||||
models := c.Models()
|
||||
if len(models) != 1 || models[0] != "new-model" {
|
||||
t.Errorf("expected [new-model], got %v", models)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "new" {
|
||||
t.Errorf("expected theme from new config, got %v", cfg["theme"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Edit migrates from legacy config", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if err := c.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should write to new path
|
||||
newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
|
||||
data, err := os.ReadFile(newPath)
|
||||
if err != nil {
|
||||
t.Fatal("expected new config file to be created")
|
||||
}
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
if cfg["theme"] != "dark" {
|
||||
t.Error("legacy theme setting was not migrated")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
|
||||
if _, err := os.Stat(configDir); !os.IsNotExist(err) {
|
||||
t.Fatal("directory should not exist before test")
|
||||
}
|
||||
|
||||
if err := c.Edit([]string{"model-a"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
t.Fatal("directory was not created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenclawOnboarded(t *testing.T) {
|
||||
c := &Openclaw{}
|
||||
|
||||
t.Run("returns false when no config exists", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no config exists")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when no wizard section")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is missing")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard.lastRunAt is empty")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when wizard.lastRunAt is set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("checks legacy clawdbot path", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if !c.onboarded() {
|
||||
t.Error("expected true when legacy config has wizard.lastRunAt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prefers new path over legacy", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
newDir := filepath.Join(tmpDir, ".openclaw")
|
||||
legacyDir := filepath.Join(tmpDir, ".clawdbot")
|
||||
os.MkdirAll(newDir, 0o755)
|
||||
os.MkdirAll(legacyDir, 0o755)
|
||||
// New path has no wizard marker
|
||||
os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
|
||||
// Legacy has wizard marker
|
||||
os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false - should prefer new path which has no wizard marker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false for corrupted JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles wrong type for wizard section", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
|
||||
|
||||
if c.onboarded() {
|
||||
t.Error("expected false when wizard is wrong type")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import (
 	"path/filepath"
 	"slices"
 	"strings"
+
+	"github.com/ollama/ollama/envconfig"
 )
 
 // OpenCode implements Runner and Editor for OpenCode integration
@@ -16,7 +18,7 @@ type OpenCode struct{}
 
 func (o *OpenCode) String() string { return "OpenCode" }
 
-func (o *OpenCode) Run(model string) error {
+func (o *OpenCode) Run(model string, args []string) error {
 	if _, err := exec.LookPath("opencode"); err != nil {
 		return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
 	}
@@ -30,7 +32,7 @@ func (o *OpenCode) Run(model string) error {
 		return fmt.Errorf("setup failed: %w", err)
 	}
 
-	cmd := exec.Command("opencode")
+	cmd := exec.Command("opencode", args...)
 	cmd.Stdin = os.Stdin
 	cmd.Stdout = os.Stdout
 	cmd.Stderr = os.Stderr
@@ -88,7 +90,7 @@ func (o *OpenCode) Edit(modelList []string) error {
 			"npm": "@ai-sdk/openai-compatible",
 			"name": "Ollama (local)",
 			"options": map[string]any{
-				"baseURL": "http://localhost:11434/v1",
+				"baseURL": envconfig.Host().String() + "/v1",
 			},
 		}
 	}
@@ -105,17 +107,26 @@ func (o *OpenCode) Edit(modelList []string) error {
 
 	for name, cfg := range models {
 		if cfgMap, ok := cfg.(map[string]any); ok {
-			if displayName, ok := cfgMap["name"].(string); ok {
-				if strings.HasSuffix(displayName, "[Ollama]") && !selectedSet[name] {
-					delete(models, name)
-				}
+			if isOllamaModel(cfgMap) && !selectedSet[name] {
+				delete(models, name)
 			}
 		}
 	}
 
 	for _, model := range modelList {
+		if existing, ok := models[model].(map[string]any); ok {
+			// migrate existing models without _launch marker
+			if isOllamaModel(existing) {
+				existing["_launch"] = true
+				if name, ok := existing["name"].(string); ok {
+					existing["name"] = strings.TrimSuffix(name, " [Ollama]")
+				}
+			}
+			continue
+		}
 		models[model] = map[string]any{
-			"name": fmt.Sprintf("%s [Ollama]", model),
+			"name": model,
+			"_launch": true,
 		}
 	}
 
@@ -201,3 +212,15 @@ func (o *OpenCode) Models() []string {
 	slices.Sort(keys)
 	return keys
 }
+
+// isOllamaModel reports whether a model config entry is managed by us
+func isOllamaModel(cfg map[string]any) bool {
+	if v, ok := cfg["_launch"].(bool); ok && v {
+		return true
+	}
+	// previously used [Ollama] as a suffix for the model managed by ollama launch
+	if name, ok := cfg["name"].(string); ok {
+		return strings.HasSuffix(name, "[Ollama]")
+	}
+	return false
+}
|
||||
|
||||
@@ -161,6 +161,76 @@ func TestOpenCodeEdit(t *testing.T) {
|
||||
assertOpenCodeModelNotExists(t, configPath, "mistral")
|
||||
})
|
||||
|
||||
t.Run("preserve user customizations on managed models", func(t *testing.T) {
|
||||
cleanup()
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Add custom fields to the model entry (simulating user edits)
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
entry["_myPref"] = "custom-value"
|
||||
entry["_myNum"] = 42
|
||||
configData, _ := json.MarshalIndent(cfg, "", " ")
|
||||
os.WriteFile(configPath, configData, 0o644)
|
||||
|
||||
// Re-run Edit — should preserve custom fields
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ = os.ReadFile(configPath)
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider = cfg["provider"].(map[string]any)
|
||||
ollama = provider["ollama"].(map[string]any)
|
||||
models = ollama["models"].(map[string]any)
|
||||
entry = models["llama3.2"].(map[string]any)
|
||||
|
||||
if entry["_myPref"] != "custom-value" {
|
||||
t.Errorf("_myPref was lost: got %v", entry["_myPref"])
|
||||
}
|
||||
if entry["_myNum"] != float64(42) {
|
||||
t.Errorf("_myNum was lost: got %v", entry["_myNum"])
|
||||
}
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
|
||||
cleanup()
|
||||
// Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
|
||||
|
||||
if err := o.Edit([]string{"llama3.2"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(configPath)
|
||||
var cfg map[string]any
|
||||
json.Unmarshal(data, &cfg)
|
||||
provider := cfg["provider"].(map[string]any)
|
||||
ollama := provider["ollama"].(map[string]any)
|
||||
models := ollama["models"].(map[string]any)
|
||||
entry := models["llama3.2"].(map[string]any)
|
||||
|
||||
// _launch marker should be added
|
||||
if v, ok := entry["_launch"].(bool); !ok || !v {
|
||||
t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
|
||||
}
|
||||
// [Ollama] suffix should be stripped
|
||||
if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
|
||||
t.Errorf("name suffix not stripped: got %q", entry["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove model preserves non-ollama models", func(t *testing.T) {
|
||||
cleanup()
|
||||
os.MkdirAll(configDir, 0o755)
|
||||
|
||||
@@ -17,6 +17,7 @@ const (
|
||||
ansiBold = "\033[1m"
|
||||
ansiReset = "\033[0m"
|
||||
ansiGray = "\033[37m"
|
||||
ansiGreen = "\033[32m"
|
||||
ansiClearDown = "\033[J"
|
||||
)
|
||||
|
||||
@@ -275,7 +276,11 @@ func parseInput(r io.Reader) (inputEvent, byte, error) {
|
||||
func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
@@ -314,7 +319,11 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int {
|
||||
func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
filtered := s.filtered()
|
||||
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
if s.filter == "" {
|
||||
fmt.Fprintf(w, "%s %sType to filter...%s\r\n", prompt, ansiGray, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(w, "%s %s\r\n", prompt, s.filter)
|
||||
}
|
||||
lineCount := 1
|
||||
|
||||
if len(filtered) == 0 {
|
||||
@@ -345,10 +354,15 @@ func renderMultiSelect(w io.Writer, prompt string, s *multiSelectState) int {
|
||||
suffix = " " + ansiGray + "(default)" + ansiReset
|
||||
}
|
||||
|
||||
desc := ""
|
||||
if item.Description != "" {
|
||||
desc = " " + ansiGray + "- " + item.Description + ansiReset
|
||||
}
|
||||
|
||||
if idx == s.highlighted && !s.focusOnButton {
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, suffix)
|
||||
fmt.Fprintf(w, " %s%s %s %s%s%s%s\r\n", ansiBold, prefix, checkbox, item.Name, ansiReset, desc, suffix)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s %s %s%s\r\n", prefix, checkbox, item.Name, suffix)
|
||||
fmt.Fprintf(w, " %s %s %s%s%s\r\n", prefix, checkbox, item.Name, desc, suffix)
|
||||
}
|
||||
lineCount++
|
||||
}
|
||||
|
||||
@@ -96,6 +96,14 @@ func TestSelectState(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
done, result, err := s.handleInput(eventEnter, 0)
|
||||
if done || result != "" || err != nil {
|
||||
t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Escape_ReturnsCancelledError", func(t *testing.T) {
|
||||
s := newSelectState(items)
|
||||
done, result, err := s.handleInput(eventEscape, 0)
|
||||
@@ -574,8 +582,19 @@ func TestRenderSelect(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "no matches") {
|
||||
t.Errorf("expected 'no matches' message, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) {
|
||||
s := newSelectState([]selectItem{})
|
||||
var buf bytes.Buffer
|
||||
renderSelect(&buf, "Select:", s)
|
||||
|
||||
if !strings.Contains(buf.String(), "no matches") {
|
||||
t.Error("expected 'no matches' message")
|
||||
t.Error("expected 'no matches' message for empty list with no filter")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,19 +10,21 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
+var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
+
 func startApp(ctx context.Context, client *api.Client) error {
 	exe, err := os.Executable()
 	if err != nil {
-		return err
+		return errNotRunning
 	}
 	link, err := os.Readlink(exe)
 	if err != nil {
-		return err
+		return errNotRunning
 	}
 	r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
 	m := r.FindStringSubmatch(link)
 	if len(m) != 1 {
-		return errors.New("could not find ollama app")
+		return errNotRunning
 	}
 	if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
 		return err
@@ -313,6 +313,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
 		conv = &deepseek2Model{}
 	case "Glm4MoeLiteForCausalLM":
 		conv = &glm4MoeLiteModel{}
+	case "GlmOcrForConditionalGeneration":
+		conv = &glmOcrModel{}
 	case "Lfm2ForCausalLM":
 		conv = &lfm2Model{}
 	default:
455
convert/convert_glmocr.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
|
||||
// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
|
||||
//
|
||||
// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
|
||||
// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
|
||||
func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
rotaryDim := int(float32(headDim) * partialRotaryFactor)
|
||||
if rotaryDim%2 != 0 {
|
||||
rotaryDim = (rotaryDim / 2) * 2 // Round down to even
|
||||
}
|
||||
|
||||
// Handle 1D (bias) or 2D (weight) tensors
|
||||
is1D := len(shape) == 1
|
||||
var inFeatures int
|
||||
if is1D {
|
||||
inFeatures = 1
|
||||
} else {
|
||||
inFeatures = int(shape[1])
|
||||
}
|
||||
outFeatures := int(shape[0])
|
||||
nEffectiveHeads := outFeatures / headDim
|
||||
|
||||
if nEffectiveHeads != nHeads {
|
||||
slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
|
||||
}
|
||||
|
||||
// Reshape to [n_heads, head_dim, in_features]
|
||||
reshaped := make([]float32, len(data))
|
||||
copy(reshaped, data)
|
||||
|
||||
// Permute the rotary dimensions: even indices first, then odd
|
||||
// For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
|
||||
result := make([]float32, len(data))
|
||||
halfRotary := rotaryDim / 2
|
||||
|
||||
for h := range nEffectiveHeads {
|
||||
for f := range inFeatures {
|
||||
for i := range halfRotary {
|
||||
// Even dim (0, 2, 4, ...) -> position i
|
||||
srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
|
||||
dstIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
|
||||
// Odd dim (1, 3, 5, ...) -> position halfRotary + i
|
||||
srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
|
||||
dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
|
||||
result[dstIdx] = reshaped[srcIdx]
|
||||
}
|
||||
|
||||
// Non-rotary part: copy as-is
|
||||
for i := rotaryDim; i < headDim; i++ {
|
||||
srcIdx := h*headDim*inFeatures + i*inFeatures + f
|
||||
result[srcIdx] = reshaped[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
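To make the reordering above concrete, here is a minimal standalone sketch (toy sizes, hypothetical helper name, not part of the converter) of what the repacker does to one head's rotary slice: even dimensions move to the front half, odd dimensions to the back half, and any non-rotary tail stays in place.

// Illustrative only: interleaved (LLaMA) -> NeoX reordering for a single head.
package main

import "fmt"

func interleavedToNeoX(head []float32, rotaryDim int) []float32 {
	out := make([]float32, len(head))
	half := rotaryDim / 2
	for i := 0; i < half; i++ {
		out[i] = head[2*i]        // even dims -> front half
		out[half+i] = head[2*i+1] // odd dims  -> back half
	}
	copy(out[rotaryDim:], head[rotaryDim:]) // non-rotary tail unchanged
	return out
}

func main() {
	head := []float32{0, 1, 2, 3, 4, 5, 6, 7} // headDim = 8
	fmt.Println(interleavedToNeoX(head, 6))   // [0 2 4 1 3 5 6 7]
}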
|
||||
|
||||
type glmOcrModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
MaxPositionEmbed uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeParameters struct {
|
||||
RopeType string `json:"rope_type"`
|
||||
MRopeSection []int32 `json:"mrope_section"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
|
||||
VisionConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
Depth uint32 `json:"depth"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
OutHiddenSize uint32 `json:"out_hidden_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
|
||||
ImageStartTokenID uint32 `json:"image_start_token_id"`
|
||||
ImageEndTokenID uint32 `json:"image_end_token_id"`
|
||||
VideoStartTokenID uint32 `json:"video_start_token_id"`
|
||||
VideoEndTokenID uint32 `json:"video_end_token_id"`
|
||||
ImageTokenID uint32 `json:"image_token_id"`
|
||||
VideoTokenID uint32 `json:"video_token_id"`
|
||||
|
||||
// Preprocessor config (preprocessor_config.json)
|
||||
Preprocessor struct {
|
||||
Size struct {
|
||||
ShortestEdge uint32 `json:"shortest_edge"`
|
||||
LongestEdge uint32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
MergeSize uint32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
} `json:"-"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*glmOcrModel)(nil)
|
||||
|
||||
func (m *glmOcrModel) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(bts, &m.Preprocessor)
|
||||
}
|
||||
|
||||
func (m *glmOcrModel) KV(t *Tokenizer) KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "glmocr"
|
||||
|
||||
// Text model parameters
|
||||
kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
|
||||
kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
|
||||
kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
|
||||
kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
|
||||
headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
|
||||
kv["glmocr.attention.key_length"] = headDim
|
||||
kv["glmocr.attention.value_length"] = headDim
|
||||
kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
|
||||
kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
|
||||
kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
|
||||
kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
|
||||
kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
|
||||
kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
|
||||
}
|
||||
|
||||
// Vision model parameters
|
||||
kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
|
||||
kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
|
||||
kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
|
||||
kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
|
||||
kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
|
||||
kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
|
||||
kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
|
||||
kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
|
||||
kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
|
||||
kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
|
||||
|
||||
// Preprocessor-derived image settings (min/max pixels and normalization)
|
||||
// Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
|
||||
if m.Preprocessor.Size.ShortestEdge > 0 {
|
||||
kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
|
||||
}
|
||||
if m.Preprocessor.Size.LongestEdge > 0 {
|
||||
kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
|
||||
}
|
||||
if len(m.Preprocessor.ImageMean) == 3 {
|
||||
kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
|
||||
}
|
||||
if len(m.Preprocessor.ImageStd) == 3 {
|
||||
kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
kv["glmocr.image_token_id"] = m.ImageTokenID
|
||||
kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
|
||||
kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
|
||||
kv["glmocr.video_token_id"] = m.VideoTokenID
|
||||
kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
|
||||
kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
|
||||
|
||||
return kv
|
||||
}
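A small aside on the cmp.Or pattern used throughout KV(): it returns its first non-zero argument (Go 1.22+), so a field that is absent from config.json, and therefore zero after unmarshalling, falls back to the hard-coded default. A minimal standalone illustration (values are made up):

// Illustrative only: why cmp.Or(configValue, default) works as a fallback.
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var depth uint32                // zero: field absent from the parsed config
	fmt.Println(cmp.Or(depth, 24))  // 24 (fallback default)
	depth = 32
	fmt.Println(cmp.Or(depth, 24))  // 32 (parsed value wins)
}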
|
||||
|
||||
func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
// Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
|
||||
numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
|
||||
skipLayer := func(name string) bool {
|
||||
// Tensor names are already replaced to "blk.N.xxx" format
|
||||
re := regexp.MustCompile(`^blk\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(name)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
blkNum, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return blkNum >= numLayers
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip next-n prediction layers (layers >= num_hidden_layers)
|
||||
if skipLayer(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split ffn_gate_up into separate gate and up projections
|
||||
if strings.Contains(name, "ffn_gate_up") {
|
||||
for t := range splitDim(t, 0,
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
|
||||
split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
|
||||
) {
|
||||
out = append(out, t)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, "patch_embd.weight") {
|
||||
shape := t.Shape()
|
||||
if len(shape) == 5 && shape[2] == 2 {
|
||||
newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
|
||||
|
||||
t0 := t.Clone()
|
||||
t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t0,
|
||||
})
|
||||
|
||||
t1 := t.Clone()
|
||||
t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tt = tensor.Materialize(tt)
|
||||
newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
|
||||
if err := tt.Reshape(newDims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t1,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if len(shape) == 4 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
|
||||
// Fall through to default handling
|
||||
}
|
||||
|
||||
// Handle pre-split patch embedding weights
|
||||
// Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
|
||||
// Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
|
||||
if strings.Contains(name, "patch_embd.0.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.Contains(name, "patch_embd.1.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle .weight.0 and .weight.1 suffix patterns
|
||||
if strings.HasSuffix(name, "patch_embd.weight.0") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "patch_embd.weight.1") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
|
||||
// GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
|
||||
// We permute at conversion time so the weights work correctly with GGML's kernel
|
||||
// This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
|
||||
if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
|
||||
strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
|
||||
// Get config values for permutation
|
||||
nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
|
||||
nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
|
||||
hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
|
||||
headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
|
||||
partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
|
||||
|
||||
// Use appropriate head count: nHeads for Q, nKVHeads for K
|
||||
effectiveHeads := nHeads
|
||||
if strings.Contains(name, "attn_k.") {
|
||||
effectiveHeads = nKVHeads
|
||||
}
|
||||
|
||||
permutedT := t.Clone()
|
||||
permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: permutedT,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
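For reference, the interleaved-to-NeoX permutation described in the comments above boils down to reordering the rows of each attention head so that the RoPE pair (2i, 2i+1) lands at rows (i, i+rotary/2). The sketch below is illustrative only: it assumes a row-major weight with nHeads*headDim rows and permutes just the first rotary rows of each head; the converter's actual `normalToNeoXRepacker` may differ in layout details.

```go
// Illustrative sketch of an interleaved -> NeoX Q/K repacker (not the actual
// normalToNeoXRepacker). Assumes a row-major weight with nHeads*headDim rows;
// rows beyond the rotary dims of each head are left untouched.
func sketchNormalToNeoX(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
	return func(_ string, data []float32, shape []uint64) ([]float32, error) {
		cols := int(shape[len(shape)-1]) // assumption: last dim is the input (column) dim
		rotary := int(float32(headDim) * partialRotaryFactor)
		out := make([]float32, len(data))
		copy(out, data)
		for h := 0; h < nHeads; h++ {
			base := h * headDim
			for i := 0; i < rotary/2; i++ {
				// NeoX row i <- interleaved row 2i; NeoX row i+rotary/2 <- interleaved row 2i+1
				copy(out[(base+i)*cols:(base+i+1)*cols], data[(base+2*i)*cols:(base+2*i+1)*cols])
				copy(out[(base+rotary/2+i)*cols:(base+rotary/2+i+1)*cols], data[(base+2*i+1)*cols:(base+2*i+2)*cols])
			}
		}
		return out, nil
	}
}
```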
|
||||
|
||||
func (m *glmOcrModel) Replacements() []string {
|
||||
return []string{
|
||||
// Vision encoder
|
||||
"model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
|
||||
"model.visual.patch_embed.proj", "v.patch_embd",
|
||||
"model.visual.blocks", "v.blk",
|
||||
"model.visual.post_layernorm", "v.post_ln",
|
||||
"model.visual.downsample", "mm.patch_merger",
|
||||
|
||||
// Vision attention
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.proj", "attn_out",
|
||||
"attn.q_norm", "attn_q_norm",
|
||||
"attn.k_norm", "attn_k_norm",
|
||||
|
||||
// Vision norms
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
|
||||
// Vision MLP
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
|
||||
// Merger (multimodal projector)
|
||||
"model.visual.merger.proj", "mm.model.fc",
|
||||
"model.visual.merger.post_projection_norm", "mm.post_norm",
|
||||
"model.visual.merger.gate_proj", "mm.gate",
|
||||
"model.visual.merger.up_proj", "mm.up",
|
||||
"model.visual.merger.down_proj", "mm.down",
|
||||
|
||||
// Language model
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.layers", "blk",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"lm_head", "output",
|
||||
|
||||
// Language model attention
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
|
||||
// Language model norms
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"post_self_attn_layernorm", "post_attn_norm",
|
||||
"post_mlp_layernorm", "post_ffn_norm",
|
||||
|
||||
// Language model MLP (remove mlp. prefix so ffn_* names work)
|
||||
"mlp.gate_up_proj", "ffn_gate_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
}
|
||||
}
|
||||
@@ -99,6 +99,7 @@ func (st safetensor) Kind() uint32 {
|
||||
if st.dtype == "BF16" &&
|
||||
!strings.HasPrefix(st.name, "v.") &&
|
||||
!strings.HasPrefix(st.name, "s.") &&
|
||||
!strings.HasPrefix(st.name, "mm.") &&
|
||||
kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
@@ -4,16 +4,6 @@ title: Anthropic compatibility
|
||||
|
||||
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
|
||||
|
||||
## Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
|
||||
|
||||
Pull a model before use:
|
||||
```shell
|
||||
ollama pull qwen3-coder
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Environment variables
|
||||
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
|
||||
export ANTHROPIC_API_KEY="" # required but ignored
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama # required but ignored
|
||||
```
|
||||
|
||||
### Simple `/v1/messages` example
|
||||
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
|
||||
|
||||
## Using with Claude Code
|
||||
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
|
||||
[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
|
||||
|
||||
### Recommended models
|
||||
|
||||
For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
|
||||
|
||||
Download a model before use:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
|
||||
ollama pull qwen3-coder
|
||||
```
|
||||
> Note: Qwen 3 Coder is a 30B-parameter model that requires at least 24 GB of VRAM to run smoothly; longer context lengths require more.
|
||||
|
||||
```shell
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Set the environment variables and run Claude Code:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Or set the environment variables in your shell profile:
|
||||
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
export ANTHROPIC_API_KEY=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
```
|
||||
|
||||
Then run Claude Code with any Ollama model:
|
||||
|
||||
```shell
|
||||
# Local models
|
||||
claude --model qwen3-coder
|
||||
claude --model gpt-oss:20b
|
||||
|
||||
# Cloud models
|
||||
claude --model glm-4.7:cloud
|
||||
claude --model minimax-m2.1:cloud
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
docs/cli.mdx (41 changes)
@@ -8,6 +8,47 @@ title: CLI Reference
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
### Launch integrations
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
|
||||
|
||||
#### Supported integrations
|
||||
|
||||
- **OpenCode** - Open-source coding assistant
|
||||
- **Claude Code** - Anthropic's agentic coding tool
|
||||
- **Codex** - OpenAI's coding assistant
|
||||
- **Droid** - Factory's AI coding agent
|
||||
|
||||
#### Examples
|
||||
|
||||
Launch an integration interactively:
|
||||
|
||||
```
|
||||
ollama launch
|
||||
```
|
||||
|
||||
Launch a specific integration:
|
||||
|
||||
```
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
Launch with a specific model:
|
||||
|
||||
```
|
||||
ollama launch claude --model qwen3-coder
|
||||
```
|
||||
|
||||
Configure without launching:
|
||||
|
||||
```
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
#### Multiline input
|
||||
|
||||
For multiline input, you can wrap text with `"""`:
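For example (the prompt and response shown here are illustrative):

```
>>> """Hello,
... world!
... """
I'm a basic program that prints the famous "Hello, world!" message.
```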
|
||||
|
||||
@@ -3,8 +3,6 @@ title: Cloud
|
||||
sidebarTitle: Cloud
|
||||
---
|
||||
|
||||
<Info>Ollama's cloud is currently in preview.</Info>
|
||||
|
||||
## Cloud Models
|
||||
|
||||
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
|
||||
|
||||
@@ -8,7 +8,7 @@ Context length is the maximum number of tokens that the model has access to in m
|
||||
The default context length in Ollama is 4096 tokens.
|
||||
</Note>
|
||||
|
||||
Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
|
||||
Tasks that require a large context, such as web search, agents, and coding tools, should use a context length of at least 64,000 tokens.
|
||||
|
||||
## Setting context length
|
||||
|
||||
@@ -24,7 +24,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
|
||||
### CLI
|
||||
If the context length cannot be changed in the Ollama app, it can also be set when starting the Ollama server:
|
||||
```
|
||||
OLLAMA_CONTEXT_LENGTH=32000 ollama serve
|
||||
OLLAMA_CONTEXT_LENGTH=64000 ollama serve
|
||||
```
|
||||
|
||||
### Check allocated context length and model offloading
|
||||
|
||||
@@ -71,6 +71,10 @@
|
||||
{
|
||||
"source": "/api",
|
||||
"destination": "/api/introduction"
|
||||
},
|
||||
{
|
||||
"source": "/integrations/clawdbot",
|
||||
"destination": "/integrations/openclaw"
|
||||
}
|
||||
],
|
||||
"navigation": {
|
||||
@@ -102,18 +106,20 @@
|
||||
"group": "Integrations",
|
||||
"pages": [
|
||||
"/integrations/claude-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/codex",
|
||||
"/integrations/cline",
|
||||
"/integrations/openclaw",
|
||||
"/integrations/codex",
|
||||
"/integrations/droid",
|
||||
"/integrations/goose",
|
||||
"/integrations/zed",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/jetbrains",
|
||||
"/integrations/marimo",
|
||||
"/integrations/n8n",
|
||||
"/integrations/xcode",
|
||||
"/integrations/onyx",
|
||||
"/integrations/marimo"
|
||||
"/integrations/opencode",
|
||||
"/integrations/roo-code",
|
||||
"/integrations/vscode",
|
||||
"/integrations/xcode",
|
||||
"/integrations/zed"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
|
||||
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|
||||
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
|
||||
@@ -134,22 +134,12 @@ success
|
||||
|
||||
### Supported Quantizations
|
||||
|
||||
- `q4_0`
|
||||
- `q4_1`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q8_0`
|
||||
|
||||
#### K-means Quantizations
|
||||
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
|
||||
## Sharing your model on ollama.com
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ sidebarTitle: Welcome
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Quickstart" icon="rocket" href="/quickstart">
|
||||
Get up and running with your first model
|
||||
Get up and running with your first model or integrate Ollama with your favorite tools
|
||||
</Card>
|
||||
<Card
|
||||
title="Download Ollama"
|
||||
|
||||
@@ -4,7 +4,7 @@ title: Claude Code
|
||||
|
||||
Claude Code is Anthropic's agentic coding tool that can read, modify, and execute code in your working directory.
|
||||
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `qwen3-coder`, `gpt-oss:20b`, or other models.
|
||||
Open models can be used with Claude Code through Ollama's Anthropic-compatible API, enabling you to use models such as `glm-4.7`, `qwen3-coder`, `gpt-oss`.
|
||||
|
||||

|
||||
|
||||
@@ -26,12 +26,27 @@ irm https://claude.ai/install.ps1 | iex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```shell
|
||||
ollama launch claude
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||
|
||||
1. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_AUTH_TOKEN=ollama
|
||||
export ANTHROPIC_API_KEY=""
|
||||
export ANTHROPIC_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
@@ -44,35 +59,17 @@ claude --model gpt-oss:20b
|
||||
Or run with environment variables inline:
|
||||
|
||||
```shell
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 claude --model gpt-oss:20b
|
||||
ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
|
||||
```
|
||||
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 32K tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) on ollama.com
|
||||
2. Set the environment variables:
|
||||
|
||||
```shell
|
||||
export ANTHROPIC_BASE_URL=https://ollama.com
|
||||
export ANTHROPIC_API_KEY=<your-api-key>
|
||||
```
|
||||
|
||||
3. Run Claude Code with a cloud model:
|
||||
|
||||
```shell
|
||||
claude --model glm-4.7:cloud
|
||||
```
|
||||
**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||
|
||||
## Recommended Models
|
||||
|
||||
### Cloud models
|
||||
- `glm-4.7:cloud` - High-performance cloud model
|
||||
- `minimax-m2.1:cloud` - Fast cloud model
|
||||
- `qwen3-coder:480b` - Large coding model
|
||||
- `qwen3-coder`
|
||||
- `glm-4.7`
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
### Local models
|
||||
- `qwen3-coder` - Excellent for coding tasks
|
||||
- `gpt-oss:20b` - Strong general-purpose model
|
||||
- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
|
||||
@@ -13,7 +13,21 @@ npm install -g @openai/codex
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.</Note>
|
||||
<Note>Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.</Note>
|
||||
|
||||
### Quick setup
|
||||
|
||||
```
|
||||
ollama launch codex
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch codex --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
To use `codex` with Ollama, use the `--oss` flag:
|
||||
|
||||
|
||||
@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
|
||||
curl -fsSL https://app.factory.ai/cli | sh
|
||||
```
|
||||
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.</Note>
|
||||
<Note>Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch droid
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch droid --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a local configuration block to `~/.factory/config.json`:
|
||||
|
||||
```json
|
||||
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
|
||||
}
|
||||
```
|
||||
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
Run `droid` in a new terminal to load the new settings.
|
||||
|
||||
docs/integrations/openclaw.mdx (new file, 50 lines)
@@ -0,0 +1,50 @@
---
title: OpenClaw
---

OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.

## Install

Install [OpenClaw](https://openclaw.ai/)

```bash
npm install -g openclaw@latest
```

Then run the onboarding wizard:

```bash
openclaw onboard --install-daemon
```

<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>

## Usage with Ollama

### Quick setup

```bash
ollama launch openclaw
```

<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>

This configures OpenClaw to use Ollama and starts the gateway. If the gateway is already running, no further action is needed; it reloads the configuration changes automatically.

To configure without launching:

```shell
ollama launch openclaw --config
```

## Recommended Models

- `qwen3-coder`
- `glm-4.7`
- `gpt-oss:20b`
- `gpt-oss:120b`

Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
docs/integrations/opencode.mdx (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
---
|
||||
title: OpenCode
|
||||
---
|
||||
|
||||
OpenCode is an open-source AI coding assistant that runs in your terminal.
|
||||
|
||||
## Install
|
||||
|
||||
Install the [OpenCode CLI](https://opencode.ai):
|
||||
|
||||
```bash
|
||||
curl -fsSL https://opencode.ai/install | bash
|
||||
```
|
||||
|
||||
<Note>OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
## Usage with Ollama
|
||||
|
||||
### Quick setup
|
||||
|
||||
```bash
|
||||
ollama launch opencode
|
||||
```
|
||||
|
||||
To configure without launching:
|
||||
|
||||
```shell
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
@@ -18,13 +18,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="CLI">
|
||||
Open a terminal and run the command:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
</Tab>
|
||||
<Tab title="cURL">
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
@@ -45,13 +45,13 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
<Tab title="Python">
|
||||
Start by downloading a model:
|
||||
|
||||
```
|
||||
```sh
|
||||
ollama pull gemma3
|
||||
```
|
||||
|
||||
Then install Ollama's Python library:
|
||||
|
||||
```
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
@@ -101,3 +101,42 @@ This quickstart will walk your through running your first model with Ollama. To
|
||||
</Tabs>
|
||||
|
||||
See a full list of available models [here](https://ollama.com/models).
|
||||
|
||||
## Coding
|
||||
|
||||
For coding use cases, we recommend using the `glm-4.7-flash` model.
|
||||
|
||||
Note: this model requires 23 GB of VRAM at a 64,000-token context length.
|
||||
```sh
|
||||
ollama pull glm-4.7-flash
|
||||
```
|
||||
|
||||
Alternatively, you can use a more powerful cloud model (with full context length):
|
||||
```sh
|
||||
ollama pull glm-4.7:cloud
|
||||
```
|
||||
|
||||
Use `ollama launch` to quickly set up a coding tool with Ollama models:
|
||||
|
||||
```sh
|
||||
ollama launch
|
||||
```
|
||||
|
||||
### Supported integrations
|
||||
|
||||
- [OpenCode](/integrations/opencode) - Open-source coding assistant
|
||||
- [Claude Code](/integrations/claude-code) - Anthropic's agentic coding tool
|
||||
- [Codex](/integrations/codex) - OpenAI's coding assistant
|
||||
- [Droid](/integrations/droid) - Factory's AI coding agent
|
||||
|
||||
### Launch with a specific model
|
||||
|
||||
```sh
|
||||
ollama launch claude --model glm-4.7-flash
|
||||
```
|
||||
|
||||
### Configure without launching
|
||||
|
||||
```sh
|
||||
ollama launch claude --config
|
||||
```
|
||||
|
||||
@@ -201,7 +201,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable Vulkan backend
|
||||
@@ -290,7 +290,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestVar(t *testing.T) {
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 4096,
|
||||
"": 0,
|
||||
"2048": 2048,
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"qwen3", "qwen3moe",
|
||||
"qwen3vl", "qwen3vlmoe",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"lfm2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
@@ -859,6 +860,7 @@ func (f GGML) FlashAttention() bool {
|
||||
"bert",
|
||||
"gemma3",
|
||||
"glm4moelite",
|
||||
"glmocr",
|
||||
"gptoss", "gpt-oss",
|
||||
"lfm2",
|
||||
"mistral3",
|
||||
|
||||
go.mod (1 change)
@@ -27,6 +27,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
|
||||
go.sum (3 changes)
@@ -174,6 +174,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -304,6 +306,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
|
||||
@@ -73,13 +73,18 @@ func manhattanDistance[V float32 | float64](v1, v2 []V) V {
|
||||
}
|
||||
|
||||
func TestEmbedCosineDistanceCorrelation(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
started := time.Now()
|
||||
for _, model := range libraryEmbedModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Since(started) > softTimeout {
|
||||
t.Skip("skipping - soft timeout exceeded")
|
||||
}
|
||||
testCases := []struct {
|
||||
a string
|
||||
b string
|
||||
@@ -489,14 +494,19 @@ func TestEmbedTruncation(t *testing.T) {
|
||||
|
||||
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
|
||||
func TestEmbedLargeInput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
started := time.Now()
|
||||
for _, model := range libraryEmbedModels {
|
||||
model := model
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Since(started) > softTimeout {
|
||||
t.Skip("skipping - soft timeout exceeded")
|
||||
}
|
||||
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer mcancel()
|
||||
|
||||
|
||||
@@ -21,9 +21,10 @@ func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
initialTimeout := 90 * time.Second
|
||||
streamTimeout := 90 * time.Second
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
@@ -47,8 +48,12 @@ func TestAPIToolCalling(t *testing.T) {
|
||||
"granite3.3": 7,
|
||||
}
|
||||
|
||||
started := time.Now()
|
||||
for _, model := range libraryToolsModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Since(started) > softTimeout {
|
||||
t.Skip("skipping - soft timeout exceeded")
|
||||
}
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
@@ -14,25 +14,28 @@ make -f Makefile.sync apply-patches
|
||||
|
||||
### Updating Base Commit
|
||||
|
||||
**Pin to new base commit**
|
||||
To update to a new base commit:
|
||||
|
||||
To change the base commit, update `FETCH_HEAD` in Makefile.sync.
|
||||
1. **Update FETCH_HEAD** in `Makefile.sync` to the new commit hash.
|
||||
|
||||
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
|
||||
2. **Check for upstreamed patches**: Before applying, check whether any patches have already been merged upstream and remove them from `./patches/` to avoid conflicts.
|
||||
|
||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||
3. **Apply patches**:
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
|
||||
```shell
|
||||
make -f Makefile.sync apply-patches
|
||||
```
|
||||
4. **Resolve conflicts** (if any): When `git am` fails on a patch:
|
||||
- Fix conflicts in `./vendor/`
|
||||
- Stage the resolved files: `git -C llama/vendor add <file>`
|
||||
- Continue: `git -C llama/vendor am --continue`
|
||||
- Re-run: `make -f Makefile.sync apply-patches`
|
||||
- Repeat until all patches are applied.
|
||||
|
||||
If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied.
|
||||
|
||||
Once all patches are applied, commit the changes to the tracking repository.
|
||||
|
||||
```shell
|
||||
make -f Makefile.sync format-patches sync
|
||||
```
|
||||
5. **Regenerate patches and sync**:
|
||||
```shell
|
||||
make -f Makefile.sync format-patches sync
|
||||
```
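Putting the numbered steps together, a typical update pass looks like this (commands taken verbatim from the steps above; `<file>` stands for whichever vendored file had conflicts):

```shell
make -f Makefile.sync apply-patches        # stops at the first conflicting patch, if any
# resolve conflicts in ./vendor/, then:
git -C llama/vendor add <file>
git -C llama/vendor am --continue
make -f Makefile.sync apply-patches        # repeat until all patches apply
make -f Makefile.sync format-patches sync  # regenerate patches and sync the tracking repo
```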
|
||||
|
||||
### Generating Patches
|
||||
|
||||
|
||||
llama/build-info.cpp (generated, vendored, 2 changes)
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "ec98e2002";
|
||||
char const *LLAMA_COMMIT = "a5bb8ba4c50257437630c136210396810741bbf7";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
llama/llama.cpp/common/common.cpp (vendored, 69 changes)
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
case GGML_SCHED_PRIO_REALTIME: p = -20; break;
|
||||
}
|
||||
|
||||
if (!setpriority(PRIO_PROCESS, 0, p)) {
|
||||
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
|
||||
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
@@ -1078,12 +1078,15 @@ struct common_init_result::impl {
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
// note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
|
||||
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
|
||||
std::vector<common_sampler_ptr> samplers;
|
||||
std::vector<llama_sampler_seq_config> samplers_seq_config;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
@@ -1092,9 +1095,9 @@ common_init_result::common_init_result(common_params & params) :
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
@@ -1107,6 +1110,25 @@ common_init_result::common_init_result(common_params & params) :
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load and optionally apply lora adapters (must be loaded before context creation)
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
pimpl->model.reset(model);
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
la.ptr = lora.get();
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
// updates params.sampling
|
||||
// TODO: fix naming
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
@@ -1141,10 +1163,18 @@ common_init_result::common_init_result(common_params & params) :
|
||||
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
// init the backend samplers as part of the context creation
|
||||
pimpl->samplers.resize(cparams.n_seq_max);
|
||||
pimpl->samplers_seq_config.resize(cparams.n_seq_max);
|
||||
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
|
||||
}
|
||||
|
||||
if (params.sampling.backend_sampling) {
|
||||
cparams.samplers = pimpl->samplers_seq_config.data();
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
@@ -1168,6 +1198,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||
return pimpl->samplers[seq_id].get();
|
||||
}
|
||||
|
||||
void common_init_result::reset_samplers() {
|
||||
for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
|
||||
llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
@@ -1243,24 +1279,6 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
la.ptr = lora.get();
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
if (!params.lora_init_without_apply) {
|
||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||
}
|
||||
@@ -1301,6 +1319,9 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
llama_set_warmup(lctx, false);
|
||||
|
||||
// reset samplers to reset RNG state after warmup to the seeded state
|
||||
res->reset_samplers();
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -1339,14 +1360,12 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.devices = params.devices.data();
|
||||
}
|
||||
|
||||
if (params.n_gpu_layers != -1) {
|
||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||
mparams.main_gpu = params.main_gpu;
|
||||
mparams.split_mode = params.split_mode;
|
||||
mparams.tensor_split = params.tensor_split;
|
||||
mparams.use_mmap = params.use_mmap;
|
||||
mparams.use_direct_io = params.use_direct_io;
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
|
||||
llama/llama.cpp/common/common.h (vendored, 93 changes)
@@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
|
||||
extern const char * LLAMA_COMPILER;
|
||||
extern const char * LLAMA_BUILD_TARGET;
|
||||
|
||||
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
|
||||
|
||||
struct common_control_vector_load_info;
|
||||
|
||||
//
|
||||
@@ -80,6 +82,8 @@ int32_t cpu_get_num_math();
|
||||
//
|
||||
|
||||
enum llama_example {
|
||||
LLAMA_EXAMPLE_BATCHED,
|
||||
LLAMA_EXAMPLE_DEBUG,
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_COMPLETION,
|
||||
@@ -117,6 +121,7 @@ enum common_sampler_type {
|
||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
|
||||
COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12,
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
@@ -164,32 +169,34 @@ enum common_params_sampling_config : uint64_t {
|
||||
struct common_params_sampling {
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float top_n_sigma = -1.00f;// -1.0 = disabled
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
|
||||
float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float top_n_sigma = -1.00f; // -1.0 = disabled
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool ignore_eos = false;
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool timing_per_token = false;
|
||||
|
||||
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||
@@ -216,6 +223,8 @@ struct common_params_sampling {
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
bool has_logit_bias() const {
|
||||
return !logit_bias.empty();
|
||||
}
|
||||
@@ -277,6 +286,7 @@ struct common_params_diffusion {
|
||||
};
|
||||
|
||||
// reasoning API response format (not to be confused as chat template's reasoning format)
|
||||
// only used by server
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
||||
@@ -329,12 +339,14 @@ struct common_params {
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
|
||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||
|
||||
// margin per device in bytes for fitting parameters to free memory:
|
||||
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
|
||||
|
||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||
|
||||
@@ -370,6 +382,11 @@ struct common_params {
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
|
||||
bool save_logits = false; // whether to save logits to files // NOLINT
|
||||
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
|
||||
|
||||
std::vector<std::string> in_files; // all input files
|
||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
@@ -420,7 +437,8 @@ struct common_params {
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mmap = true; // enable mmap to use filesystem cache
|
||||
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
@@ -464,6 +482,7 @@ struct common_params {
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
@@ -475,7 +494,8 @@ struct common_params {
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -484,8 +504,11 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// webui configs
|
||||
bool webui = true;
|
||||
std::string webui_config_json;
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
bool webui = true;
|
||||
bool endpoint_slots = true;
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
@@ -685,7 +708,9 @@ struct common_init_result {
|
||||
|
||||
llama_model * model();
|
||||
llama_context * context();
|
||||
|
||||
common_sampler * sampler(llama_seq_id seq_id);
|
||||
void reset_samplers();
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
|
||||
llama/llama.cpp/common/sampling.cpp (vendored, 229 changes)
@@ -104,10 +104,9 @@ struct ring_buffer {
|
||||
struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * chain;
|
||||
|
||||
bool grammar;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
|
||||
std::vector<llama_token_data> cur;
|
||||
@@ -121,17 +120,34 @@ struct common_sampler {
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
|
||||
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
cur.resize(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
|
||||
cur.resize(sampled_probs_count);
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
|
||||
cur.resize(sampled_logits_count);
|
||||
for (uint32_t i = 0; i < sampled_logits_count; i++) {
|
||||
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
}
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
@@ -151,54 +167,59 @@ std::string common_params_sampling::print() const {
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
|
||||
|
||||
return std::string(result);
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
llama_sampler * grmr = nullptr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
bool grammar = false;
|
||||
std::vector<llama_sampler *> samplers;
|
||||
|
||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
|
||||
grammar = true;
|
||||
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||
#else
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
trigger_patterns.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
const auto & pattern = trigger.value;
|
||||
std::string anchored = "^$";
|
||||
if (!pattern.empty()) {
|
||||
anchored = (pattern.front() != '^' ? "^" : "")
|
||||
+ pattern
|
||||
+ (pattern.back() != '$' ? "$" : "");
|
||||
}
|
||||
trigger_patterns.push_back(anchored);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
@@ -212,10 +233,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
}
|
||||
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::vector<const char *> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto & regex : trigger_patterns) {
|
||||
@@ -224,15 +241,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
if (params.grammar_lazy) {
|
||||
samplers.push_back(
|
||||
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size()));
|
||||
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size());
|
||||
} else {
|
||||
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
|
||||
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
|
||||
grammar = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
|
||||
bool use_adaptive_p = false; // see below
|
||||
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
@@ -250,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
||||
samplers.push_back(llama_sampler_init_top_k(params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
samplers.push_back(llama_sampler_init_infill (vocab));
|
||||
samplers.push_back(llama_sampler_init_infill(vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
|
||||
// the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
|
||||
// a single token, so we will add `dist` at the end of the chain by default,
|
||||
// unless the user specifically included `adaptive-p`. we set this flag here
|
||||
// so we know to add the sampler at the very end.
|
||||
use_adaptive_p = true;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
if (use_adaptive_p) {
|
||||
// only if user explicitly included adaptive-p sampler
|
||||
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
|
||||
} else {
|
||||
// default: sample from distribution
|
||||
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
@@ -301,10 +329,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
llama_sampler_chain_add(chain, smpl);
|
||||
}
|
||||
|
||||
if (grmr && params.backend_sampling) {
|
||||
LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
|
||||
|
||||
params.backend_sampling = false;
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ chain,
|
||||
/* .grammar = */ grammar,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
@@ -314,47 +348,45 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
if (gsmpl->grammar) {
|
||||
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
|
||||
|
||||
for (int i = 0; i < n_smpl; i++) {
|
||||
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
|
||||
// the grammar sampler is always the first one
|
||||
if (i == 0) {
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
} else {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
llama_sampler_accept(gsmpl->chain, token);
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
|
||||
llama_sampler_accept(gsmpl->chain, token);
|
||||
|
||||
gsmpl->prev.push_back(token);
|
||||
}
|
||||
|
||||
void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
}
|
||||
|
||||
gsmpl->reset();
|
||||
}
|
||||
|
||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .grammar = */ gsmpl->grammar,
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
@@ -407,10 +439,14 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||
if (!gsmpl) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gsmpl->chain;
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||
@@ -418,11 +454,61 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
// Check if a backend sampler has already sampled a token in which case we
|
||||
// return that token id directly.
|
||||
{
|
||||
id = llama_get_sampled_token_ith(ctx, idx);
|
||||
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
|
||||
// TODO: simplify
|
||||
gsmpl->cur.resize(1);
|
||||
gsmpl->cur[0] = { id, 0.0f, 1.0f };
|
||||
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
|
||||
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
if (grammar_first) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
|
||||
{
|
||||
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
||||
|
||||
llama_sampler_apply(grmr, &single_token_data_array);
|
||||
|
||||
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||
if (is_valid) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// resampling:
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
@@ -432,7 +518,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
return id;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||
|
||||
std::vector<llama_token> result;
|
||||
@@ -440,7 +526,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
|
||||
size_t i = 0;
|
||||
for (; i < draft.size(); i++) {
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||
|
||||
common_sampler_accept(gsmpl, id, true);
|
||||
|
||||
@@ -452,7 +538,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
}
|
||||
|
||||
if (i == draft.size()) {
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
|
||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||
|
||||
common_sampler_accept(gsmpl, id, true);
|
||||
|
||||
@@ -462,13 +548,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
|
||||
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||
std::vector<int> idxs(draft.size() + 1);
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
idxs[i] = i;
|
||||
}
|
||||
|
||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
|
||||
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||
}
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
@@ -553,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
|
||||
default : return '?';
|
||||
}
|
||||
}
|
||||
@@ -569,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
|
||||
default : return "";
|
||||
}
|
||||
}
|
||||
@@ -585,6 +673,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
||||
};
|
||||
|
||||
// since samplers names are written multiple ways
|
||||
@@ -600,6 +689,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
@@ -636,6 +726,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
||||
13 llama/llama.cpp/common/sampling.h vendored
@@ -36,7 +36,8 @@ struct common_sampler;

// llama_sampler API overloads

struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
// note: can mutate params in some cases
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);

void common_sampler_free(struct common_sampler * gsmpl);

@@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);

// get the underlying llama_sampler_chain
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);

// extended sampling implementation:
@@ -57,7 +59,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// generalized version of common_sampler_sample
//
@@ -75,10 +80,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

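Note: the comments above describe the extended sampling path: sample with the chain first, verify the token against the grammar, and resample only on rejection; grammar_first instead forces the grammar before the chain. A minimal caller-side sketch of this API follows; `model`, `ctx` and `n_predict` are assumed to exist and are not part of this diff.

// sketch: generation loop using the default fast path (grammar_first = false)
common_params_sampling sparams;                    // defaults; note that common_sampler_init may mutate them
common_sampler * smpl = common_sampler_init(model, sparams);

for (int i = 0; i < n_predict; ++i) {
    // idx = -1 samples from the last set of logits produced by the previous llama_decode call
    const llama_token id = common_sampler_sample(smpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false);
    common_sampler_accept(smpl, id, /*accept_grammar =*/ true);

    if (llama_vocab_is_eog(llama_model_get_vocab(model), id)) {
        break;                                     // end of generation
    }
    // ... feed `id` back through llama_decode and continue
}

common_sampler_free(smpl);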
4 llama/llama.cpp/include/llama-cpp.h vendored
@@ -21,7 +21,9 @@ struct llama_sampler_deleter {
};

struct llama_adapter_lora_deleter {
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
void operator()(llama_adapter_lora *) {
// llama_adapter_lora_free is deprecated
}
};

typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;

153 llama/llama.cpp/include/llama.h vendored
@@ -286,7 +286,7 @@ extern "C" {
|
||||
// NULL-terminated list of buffer types to use for tensors that match a pattern
|
||||
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
|
||||
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
|
||||
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||
|
||||
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
|
||||
@@ -309,6 +309,7 @@ extern "C" {
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_direct_io; // use direct io, takes precedence over use_mmap
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
@@ -316,6 +317,11 @@ extern "C" {
|
||||
bool no_alloc; // only load metadata and simulate memory allocations
|
||||
};
|
||||
|
||||
struct llama_sampler_seq_config {
|
||||
llama_seq_id seq_id;
|
||||
struct llama_sampler * sampler;
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
// https://github.com/ggml-org/llama.cpp/pull/7544
|
||||
struct llama_context_params {
|
||||
@@ -364,6 +370,12 @@ extern "C" {
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
|
||||
// [EXPERIMENTAL]
|
||||
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
|
||||
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||
struct llama_sampler_seq_config * samplers;
|
||||
size_t n_samplers;
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
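Note: as a rough illustration of the [EXPERIMENTAL] samplers/n_samplers fields above, a caller could attach one sampler chain per sequence at context creation time. The chain contents and seq_id below are only example values, the caller must keep the chain alive for the lifetime of the context as the comment requires, and `model` is assumed to be loaded elsewhere.

// sketch: one backend sampler chain for sequence 0
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

struct llama_sampler_seq_config cfg = { /*seq_id =*/ 0, /*sampler =*/ chain };

struct llama_context_params cparams = llama_context_default_params();
cparams.samplers   = &cfg;
cparams.n_samplers = 1;

struct llama_context * ctx = llama_init_from_model(model, cparams);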
@@ -467,16 +479,24 @@ extern "C" {
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
enum llama_params_fit_status {
|
||||
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
|
||||
};
|
||||
|
||||
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
|
||||
// returns true if the parameters could be successfully modified to fit device memory
|
||||
// this function is NOT thread safe because it modifies the global llama logger state
|
||||
LLAMA_API bool llama_params_fit(
|
||||
// - returns true if the parameters could be successfully modified to fit device memory
|
||||
// - this function is NOT thread safe because it modifies the global llama logger state
|
||||
// - only parameters that have the same value as in llama_default_model_params are modified
|
||||
// with the exception of the context size which is modified if and only if equal to 0
|
||||
LLAMA_API enum llama_params_fit_status llama_params_fit(
|
||||
const char * path_model,
|
||||
struct llama_model_params * mparams,
|
||||
struct llama_context_params * cparams,
|
||||
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
|
||||
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
|
||||
size_t margin, // margin of memory to leave per device in bytes
|
||||
size_t * margins, // margins of memory to leave per device in bytes
|
||||
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
|
||||
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
|
||||
|
||||
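Note: a hedged sketch of calling the revised llama_params_fit above; the model path, buffer sizes and margin values are placeholders, and only the behaviour documented in the comments and the status enum is assumed.

// sketch: try to shrink the default params so the model is projected to fit in device memory
struct llama_model_params   mparams = llama_model_default_params();
struct llama_context_params cparams = llama_context_default_params();

float  tensor_split[16] = {0};                              // needs at least llama_max_devices elements
size_t margins[16]      = {0};                              // per-device memory margin in bytes
struct llama_model_tensor_buft_override overrides[256] = {0}; // needs at least llama_max_tensor_buft_overrides elements

enum llama_params_fit_status status = llama_params_fit(
        "/path/to/model.gguf", &mparams, &cparams,
        tensor_split, overrides, margins,
        /*n_ctx_min =*/ 4096, GGML_LOG_LEVEL_INFO);

if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
    // fall back to the unmodified defaults or abort
}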
@@ -517,6 +537,7 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
@@ -600,6 +621,8 @@ extern "C" {
|
||||
//
|
||||
|
||||
// Load a LoRA adapter from file
|
||||
// The adapter is valid as long as the associated model is not freed
|
||||
// All adapters must be loaded before context creation
|
||||
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
|
||||
struct llama_model * model,
|
||||
const char * path_lora);
|
||||
@@ -624,7 +647,8 @@ extern "C" {
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// NOTE: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter),
|
||||
"adapters are now freed together with the associated model");
|
||||
|
||||
// Get the invocation tokens if the current lora is an alora
|
||||
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
|
||||
@@ -983,6 +1007,32 @@ extern "C" {
|
||||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
//
|
||||
// backend sampling API [EXPERIMENTAL]
|
||||
// note: use only if the llama_context was created with at least one llama_sampler_seq_config
|
||||
//
|
||||
|
||||
// Get the backend sampled token for the ith token.
|
||||
// Returns LLAMA_TOKEN_NULL if no token was sampled.
|
||||
LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled probabilites for the ith token
|
||||
// The index matches llama_get_sampled_token_ith().
|
||||
// Returns NULL if no probabilites were generated.
|
||||
LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled logits for the ith token
|
||||
// Returns NULL if no logits were sampled.
|
||||
LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
// Get the backend sampled candidates (token ids) for the ith token
|
||||
// These are needed to map probability/logit indices to vocab token ids.
|
||||
// Returns NULL if no candidates were sampled.
|
||||
LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
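Note: to make the read side of this [EXPERIMENTAL] API concrete, here is a rough sketch of pulling the backend-sampled result for output index i after llama_decode(); it assumes the context was created with a backend sampler chain as described earlier, and the probability inspection is purely illustrative.

// sketch: after llama_decode(ctx, batch) has run with backend sampling enabled
const llama_token id = llama_get_sampled_token_ith(ctx, i);
if (id != LLAMA_TOKEN_NULL) {
    // optionally inspect the distribution produced by the backend sampler
    const float       * probs = llama_get_sampled_probs_ith(ctx, i);
    const llama_token * cands = llama_get_sampled_candidates_ith(ctx, i);
    const uint32_t      n     = llama_get_sampled_probs_count_ith(ctx, i);

    for (uint32_t k = 0; probs && cands && k < n; ++k) {
        // cands[k] is the vocab token id whose probability is probs[k]
    }
} else {
    // no backend-sampled token for this output: fall back to CPU-side sampling
}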
@@ -1154,11 +1204,16 @@ extern "C" {
|
||||
//
|
||||
// llama_sampler_free(smpl);
|
||||
//
|
||||
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
//
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
|
||||
struct llama_sampler_data {
|
||||
struct ggml_tensor * logits;
|
||||
struct ggml_tensor * probs;
|
||||
struct ggml_tensor * sampled;
|
||||
struct ggml_tensor * candidates;
|
||||
};
|
||||
|
||||
// user code can implement the interface below in order to create custom llama_sampler
|
||||
struct llama_sampler_i {
|
||||
const char * (*name) (const struct llama_sampler * smpl); // can be NULL
|
||||
@@ -1168,17 +1223,44 @@ extern "C" {
|
||||
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
|
||||
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
||||
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
|
||||
// [EXPERIMENTAL]
|
||||
// backend sampling interface:
|
||||
|
||||
// return true if the backend supports all ops needed by the sampler
|
||||
// note: call once per sampler
|
||||
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
|
||||
|
||||
// call after .backend_apply()
|
||||
void (*backend_accept)(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token);
|
||||
|
||||
// call after .backend_init()
|
||||
void (*backend_apply)(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_data * data);
|
||||
|
||||
// called before graph execution to set inputs for the current ubatch
|
||||
void (*backend_set_input)(struct llama_sampler * smpl);
|
||||
};
|
||||
|
||||
struct llama_sampler {
|
||||
const struct llama_sampler_i * iface;
|
||||
llama_sampler_context_t ctx;
|
||||
struct llama_sampler_i * iface;
|
||||
|
||||
llama_sampler_context_t ctx;
|
||||
};
|
||||
|
||||
// [EXPERIMENTAL]
|
||||
// attach a sampler to the context
|
||||
// note: prefer initializing the context with llama_context_params.samplers when possible
|
||||
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
||||
|
||||
// mirror of llama_sampler_i:
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
||||
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
||||
@@ -1194,7 +1276,15 @@ extern "C" {
|
||||
|
||||
// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
|
||||
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
|
||||
|
||||
// return NULL if:
|
||||
// - the sampler is NULL
|
||||
// - the sampler is not a llama_sampler_chain
|
||||
// - the index is out of bounds, unless i == -1
|
||||
// - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);
|
||||
|
||||
// the total number of samplers in the chain
|
||||
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
|
||||
|
||||
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||
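Note: a small consequence of the i == -1 behaviour documented above (sketch; `smpl` is any llama_sampler):

if (llama_sampler_chain_get(smpl, -1) != NULL) {
    // smpl is a llama_sampler_chain
}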
@@ -1203,7 +1293,9 @@ extern "C" {
|
||||
// available samplers:
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
|
||||
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
@@ -1304,6 +1396,33 @@ extern "C" {
|
||||
const char ** seq_breakers,
|
||||
size_t num_breakers);
|
||||
|
||||
/// adaptive-p: select tokens near a configurable target probability over time.
|
||||
///
|
||||
/// the adaptive-p sampler transforms the token probability distribution to favor tokens
|
||||
/// that fall near a user-configurable probability target.
|
||||
///
|
||||
/// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
|
||||
/// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
|
||||
/// adapted target probability at each sampling step, thus maintaining the desired target
|
||||
/// probability over time.
|
||||
///
|
||||
/// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
|
||||
/// in the sampler chain (like mirostat, dist, greedy).
|
||||
///
|
||||
/// only mild truncation before this sampler is recommended. we suggest applying min-p
|
||||
/// before adaptive-p as the only other active sampler in the chain.
|
||||
///
|
||||
/// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
|
||||
/// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
|
||||
/// @param seed RNG seed
|
||||
///
|
||||
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927
|
||||
///
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
|
||||
float target,
|
||||
float decay,
|
||||
uint32_t seed);
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
||||
int32_t n_vocab,
|
||||
int32_t n_logit_bias,
|
||||
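Note: following the recommendation in the adaptive-p documentation above (only mild truncation such as min-p before it, and adaptive-p last because it selects the token), a chain might be assembled as sketched below; the target/decay values are illustrative only and `ctx` is assumed to come from llama_init_from_model.

// sketch: min-p -> adaptive-p, with adaptive-p as the terminal (token-selecting) sampler
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, /*min_keep =*/ 1));
llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(/*target =*/ 0.5f, /*decay =*/ 0.9f, LLAMA_DEFAULT_SEED));

const llama_token id = llama_sampler_sample(chain, ctx, /*idx =*/ -1);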
@@ -1357,12 +1476,12 @@ extern "C" {
|
||||
/// @details Build a split GGUF final path for this chunk.
|
||||
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
|
||||
// Returns the split_path length.
|
||||
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
|
||||
LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count);
|
||||
|
||||
/// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
|
||||
/// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
|
||||
// Returns the split_prefix length.
|
||||
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
|
||||
LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
||||
7 llama/llama.cpp/src/llama-adapter.cpp vendored
@@ -411,6 +411,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
}

// register adapter with model
model.loras.insert(&adapter);

LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}

@@ -468,8 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
return snprintf(buf, buf_size, "%s", it->second.c_str());
}

void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
void llama_adapter_lora_free(llama_adapter_lora *) {
// deprecated: adapters are freed by llama_model's destructor
}

uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {

4 llama/llama.cpp/src/llama-adapter.h vendored
@@ -77,6 +77,10 @@ struct llama_adapter_lora {
~llama_adapter_lora() = default;

llama_adapter_lora_weight * get_weight(ggml_tensor * w);

uint32_t get_n_nodes() const {
return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat
}
};

using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

115 llama/llama.cpp/src/llama-arch.cpp vendored
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_STARCODER, "starcoder" },
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
||||
@@ -41,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_PLAMO2, "plamo2" },
|
||||
{ LLM_ARCH_PLAMO3, "plamo3" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||
@@ -79,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_EXAONE4, "exaone4" },
|
||||
{ LLM_ARCH_EXAONE_MOE, "exaone-moe" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
@@ -115,6 +118,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_RND1, "rnd1" },
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_MIMO2, "mimo2" },
|
||||
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
|
||||
{ LLM_ARCH_MAINCODER, "maincoder" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -148,6 +154,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
|
||||
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
|
||||
{ LLM_KV_FEATURES_LENGTH, "%s.features_length" },
|
||||
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
|
||||
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
|
||||
@@ -205,6 +212,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
@@ -216,6 +224,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
@@ -500,6 +509,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_DECI:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
@@ -781,6 +791,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
};
|
||||
case LLM_ARCH_MODERN_BERT:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_TOKEN_EMBD_NORM,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_QKV,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
};
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
@@ -930,6 +954,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_QKV,
|
||||
LLM_TENSOR_ATTN_GATE,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
@@ -1060,6 +1086,22 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
};
|
||||
case LLM_ARCH_PLAMO3:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_QKV,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
};
|
||||
case LLM_ARCH_CODESHELL:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
@@ -1690,6 +1732,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
};
|
||||
case LLM_ARCH_EXAONE_MOE:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS,
|
||||
LLM_TENSOR_FFN_UP_EXPS,
|
||||
LLM_TENSOR_FFN_GATE_SHEXP,
|
||||
LLM_TENSOR_FFN_UP_SHEXP,
|
||||
LLM_TENSOR_FFN_DOWN_SHEXP,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
LLM_TENSOR_NEXTN_HNORM,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
||||
};
|
||||
case LLM_ARCH_RWKV6:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
@@ -2040,6 +2114,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM_LFM2,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_DENSE_2_OUT,
|
||||
};
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
return {
|
||||
@@ -2058,7 +2133,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_SHORTCONV_INPROJ,
|
||||
LLM_TENSOR_SHORTCONV_OUTPROJ,
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT_NORM_LFM2,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS,
|
||||
@@ -2174,11 +2249,49 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_VISEXP_FFN_DOWN,
|
||||
LLM_TENSOR_VISEXP_FFN_UP,
|
||||
};
|
||||
case LLM_ARCH_MIMO2:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS,
|
||||
LLM_TENSOR_FFN_UP_EXPS,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
};
|
||||
case LLM_ARCH_GPTJ:
|
||||
case LLM_ARCH_UNKNOWN:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
};
|
||||
case LLM_ARCH_MAINCODER:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
};
|
||||
case LLM_ARCH_SOLAR:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
|
||||
9 llama/llama.cpp/src/llama-arch.h vendored
@@ -24,6 +24,7 @@ enum llm_arch {
|
||||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_MODERN_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
@@ -45,6 +46,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_PLAMO2,
|
||||
LLM_ARCH_PLAMO3,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
@@ -83,6 +85,7 @@ enum llm_arch {
|
||||
LLM_ARCH_NEMOTRON_H_MOE,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_EXAONE4,
|
||||
LLM_ARCH_EXAONE_MOE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
@@ -119,6 +122,9 @@ enum llm_arch {
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_MIMO2,
|
||||
LLM_ARCH_LLAMA_EMBED,
|
||||
LLM_ARCH_MAINCODER,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -152,6 +158,7 @@ enum llm_kv {
|
||||
LLM_KV_VOCAB_SIZE,
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH_OUT,
|
||||
LLM_KV_FEATURES_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
||||
@@ -209,6 +216,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
@@ -220,6 +228,7 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
|
||||
31 llama/llama.cpp/src/llama-chat.cpp vendored
@@ -57,6 +57,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
||||
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
||||
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
|
||||
{ "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
|
||||
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
@@ -74,6 +75,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
|
||||
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
|
||||
{ "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -136,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
} else if (tmpl_contains("[gMASK]<sop>")) {
|
||||
return LLM_CHAT_TEMPLATE_CHATGLM_4;
|
||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
||||
if (tmpl_contains("<|tool_declare|>")) {
|
||||
return LLM_CHAT_TEMPLATE_EXAONE_MOE;
|
||||
}
|
||||
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||
@@ -216,6 +221,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_GROK_2;
|
||||
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
|
||||
return LLM_CHAT_TEMPLATE_PANGU_EMBED;
|
||||
} else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) {
|
||||
return LLM_CHAT_TEMPLATE_SOLAR_OPEN;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -573,6 +580,22 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "[|assistant|]";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n";
|
||||
} else if (role == "user") {
|
||||
ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n";
|
||||
} else if (role == "tool") {
|
||||
ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
|
||||
// this template requires the model to have "\n\n" as EOT token
|
||||
for (size_t i = 0; i < chat.size(); i++) {
|
||||
@@ -845,6 +868,14 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "[unused9]助手:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|begin|>assistant";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
2 llama/llama.cpp/src/llama-chat.h vendored
@@ -36,6 +36,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_EXAONE_MOE,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -54,6 +55,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_PANGU_EMBED,
LLM_CHAT_TEMPLATE_SOLAR_OPEN,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

1066 llama/llama.cpp/src/llama-context.cpp vendored
File diff suppressed because it is too large
54 llama/llama.cpp/src/llama-context.h vendored
@@ -40,6 +40,14 @@ struct llama_context {
|
||||
|
||||
~llama_context();
|
||||
|
||||
// reserve a new backend scheduler (if needed)
|
||||
// for example, when:
|
||||
// - changing loras
|
||||
// - changing samplers
|
||||
// - changing attention type
|
||||
// - etc.
|
||||
void sched_reserve();
|
||||
|
||||
void synchronize();
|
||||
|
||||
const llama_model & get_model() const;
|
||||
@@ -70,6 +78,18 @@ struct llama_context {
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
llama_token * get_sampled_tokens() const;
|
||||
llama_token get_sampled_token_ith(int32_t idx);
|
||||
|
||||
float * get_sampled_logits_ith(int32_t idx);
|
||||
size_t get_sampled_logits_count(int32_t idx);
|
||||
|
||||
float * get_sampled_probs_ith(int32_t idx);
|
||||
size_t get_sampled_probs_count(int32_t idx);
|
||||
|
||||
const llama_token * get_sampled_candidates_ith(int32_t idx);
|
||||
size_t get_sampled_candidates_count(int32_t idx);
|
||||
|
||||
void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
@@ -192,10 +212,13 @@ private:
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
uint32_t output_reserve(int32_t n_outputs);
|
||||
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
|
||||
|
||||
void output_reorder();
|
||||
|
||||
// map the output row index `i` to batch index
|
||||
int64_t output_resolve_row(int32_t i) const;
|
||||
|
||||
//
|
||||
// graph
|
||||
//
|
||||
@@ -213,6 +236,8 @@ public:
|
||||
ggml_cgraph * graph_reserve(
|
||||
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
|
||||
|
||||
bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
llm_graph_result * res,
|
||||
@@ -252,6 +277,31 @@ private:
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
float * embd = nullptr;
|
||||
|
||||
// TODO: simplify
|
||||
struct sampling_info {
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
||||
llama_token * sampled = nullptr;
|
||||
size_t sampled_size = 0;
|
||||
|
||||
float * probs = nullptr;
|
||||
size_t probs_size = 0;
|
||||
|
||||
llama_token * candidates = nullptr;
|
||||
size_t candidates_size = 0;
|
||||
|
||||
std::vector<uint32_t> logits_count;
|
||||
std::vector<uint32_t> probs_count;
|
||||
std::vector<uint32_t> candidates_count;
|
||||
|
||||
std::vector<llama_token> token_ids_full_vocab;
|
||||
};
|
||||
|
||||
sampling_info sampling;
|
||||
|
||||
// sequence embeddings output (map of [n_embd] vectors)
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
@@ -272,6 +322,8 @@ private:
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
bool sched_need_reserve = true;
|
||||
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
std::vector<ggml_backend_ptr> backends;
|
||||
|
||||
|
||||
2 llama/llama.cpp/src/llama-cparams.h vendored
@@ -30,10 +30,12 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool auto_fa;
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
bool pipeline_parallel;

enum llama_pooling_type pooling_type;

53 llama/llama.cpp/src/llama-grammar.cpp vendored
@@ -369,6 +369,44 @@ static void print_rule(
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
//
|
||||
// Regex utilities
|
||||
//
|
||||
|
||||
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||
auto find_start_pos = [](const std::smatch & match) {
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
return start;
|
||||
};
|
||||
|
||||
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
|
||||
// match against the entire input
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, regex)) {
|
||||
return find_start_pos(match);
|
||||
}
|
||||
}
|
||||
|
||||
// search anywhere
|
||||
std::smatch match;
|
||||
if (std::regex_search(input, match, regex)) {
|
||||
return find_start_pos(match);
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// implementation
|
||||
//
|
||||
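Note: the helper above checks fully anchored patterns (^...$) against the whole buffer with std::regex_match, otherwise searches anywhere with std::regex_search, and in both cases reports the start of the first non-empty capturing group. A simplified, self-contained illustration of that behaviour follows; the example pattern in the trailing comment is invented.

// sketch: anchored vs. unanchored trigger matching, loosely mirroring llama_grammar_trigger_pattern::find
#include <regex>
#include <string>

static size_t find_trigger(const std::string & pattern, const std::string & input) {
    const std::regex re(pattern);
    std::smatch m;
    const bool anchored = !pattern.empty() && pattern.front() == '^' && pattern.back() == '$';
    const bool hit = anchored ? std::regex_match(input, m, re) : std::regex_search(input, m, re);
    if (!hit) {
        return std::string::npos;
    }
    // prefer the first non-empty capturing group, otherwise the start of the whole match
    for (size_t i = 1; i < m.size(); ++i) {
        if (m.length(i) > 0) {
            return (size_t) m.position(i);
        }
    }
    return (size_t) m.position(0);
}

// find_trigger("^[\\s\\S]*?(<tool_call>)[\\s\\S]*$", "some text <tool_call>{...}") -> index of "<tool_call>"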
@@ -1321,21 +1359,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||
grammar.trigger_buffer += piece;
|
||||
|
||||
std::smatch match;
|
||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
auto start = trigger_pattern.find(grammar.trigger_buffer);
|
||||
if (start != std::string::npos) {
|
||||
grammar.awaiting_trigger = false;
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
|
||||
// replay tokens that overlap with [start, end)
|
||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||
|
||||
2 llama/llama.cpp/src/llama-grammar.h vendored
@@ -130,6 +130,8 @@ struct llama_grammar_parser {
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;

size_t find(const std::string & input) const;
};

struct llama_grammar {

548 llama/llama.cpp/src/llama-graph.cpp vendored
@@ -7,11 +7,13 @@
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
@@ -21,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
if (ubatch->embd) {
|
||||
const int64_t n_embd = embd->ne[0];
|
||||
GGML_ASSERT(n_embd == embd->ne[0]);
|
||||
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
||||
@@ -31,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -62,7 +65,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
||||
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens;
|
||||
res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -95,11 +98,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
int32_t * data = (int32_t *) pos_bucket->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
||||
}
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -322,34 +323,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
|
||||
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
|
||||
const uint64_t idst = i1*n_kv;
|
||||
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
const llama_pos p0 = ubatch->pos[i0];
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
const llama_pos p0 = ubatch->pos[i0];
|
||||
|
||||
// mask different sequences
|
||||
if (s0 != s1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// mask future tokens
|
||||
if (cparams.causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
// mask different sequences
|
||||
if (s0 != s1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// mask future tokens
|
||||
if (cparams.causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -408,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
@@ -453,27 +473,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
data[i*n_enc + j] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -521,6 +533,113 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
const auto * attn_ctx = mctx->get_attn();
|
||||
|
||||
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
|
||||
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
||||
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
||||
|
||||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
||||
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->get_recr()->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
const auto * attn_ctx = mctx->get_attn();
|
||||
|
||||
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
}
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
res &= inp_rs->head == mctx->get_recr()->get_head();
|
||||
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||
// set the inputs only for the active samplers in the current ubatch
|
||||
std::unordered_set<llama_seq_id> active_samplers;
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
|
||||
if (ubatch->output[i]) {
|
||||
llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
active_samplers.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto seq_id : active_samplers) {
|
||||
if (samplers.find(seq_id) == samplers.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto & sampler = samplers[seq_id];
|
||||
|
||||
if (sampler->iface->backend_set_input) {
|
||||
sampler->iface->backend_set_input(sampler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
||||
if (samplers.size() != params.samplers.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, sampler] : params.samplers) {
|
||||
if (samplers[seq_id] != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@@ -537,10 +656,15 @@ int64_t llm_graph_result::get_max_nodes() const {
|
||||
}
|
||||
|
||||
void llm_graph_result::reset() {
|
||||
t_tokens = nullptr;
|
||||
t_inp_tokens = nullptr;
|
||||
t_inp_embd = nullptr;
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
t_sampled.clear();
|
||||
t_sampled_probs.clear();
|
||||
t_sampled_logits.clear();
|
||||
t_candidates.clear();
|
||||
|
||||
params = {};
|
||||
|
||||
@@ -565,6 +689,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_result::set_outputs() {
|
||||
if (t_logits != nullptr) {
|
||||
ggml_set_output(t_logits);
|
||||
}
|
||||
if (t_embd != nullptr) {
|
||||
ggml_set_output(t_embd);
|
||||
}
|
||||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_probs) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_logits) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_candidates) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
||||
if (!this->params.allow_reuse(params)) {
|
||||
if (debug > 1) {
|
||||
@@ -646,6 +802,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
res (params.res),
|
||||
ctx0 (res->get_ctx()),
|
||||
@@ -1204,17 +1361,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
// input embeddings with optional lora
|
||||
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_embd_inp = hparams.n_embd_inp();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
assert(n_embd_inp >= n_embd);
|
||||
|
||||
ggml_tensor * cur = nullptr;
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
|
||||
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
|
||||
cb(inp->embd, "inp_embd", -1);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
// select one of the 2 inputs, based on the batch contents
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
||||
std::array<ggml_tensor *, 2> inps;
|
||||
|
||||
// token embeddings path (ubatch.token != nullptr)
|
||||
{
|
||||
auto & cur = inps[0];
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
@@ -1235,22 +1404,43 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpL_delta);
|
||||
}
|
||||
} else {
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// vector embeddings path (ubatch.embd != nullptr)
|
||||
{
|
||||
auto & cur = inps[1];
|
||||
|
||||
cur = inp->embd;
|
||||
}
|
||||
|
||||
assert(ggml_are_same_shape (inps[0], inps[1]));
|
||||
assert(ggml_are_same_stride(inps[0], inps[1]));
|
||||
|
||||
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
|
||||
}
|
||||
|
||||
res->t_inp_embd = cur;
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_embedding_scale != 0.0f) {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
||||
}
|
||||
|
||||
cb(cur, "inp_embd", -1);
|
||||
cb(cur, "embd", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
// make sure the produced embeddings are immediately materialized in the ggml graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18599
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -1342,7 +1532,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
||||
//}
|
||||
|
||||
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
||||
ggml_set_input(cur);
|
||||
@@ -1440,6 +1630,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
|
||||
}
|
||||
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
||||
|
||||
@@ -1649,9 +1844,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * v_mla, // TODO: remove
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
GGML_ASSERT(v_mla == nullptr);
|
||||
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
// expand k later to enable rope fusion which directly writes into k-v cache
|
||||
@@ -1694,6 +1891,93 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
||||
ggml_context * ctx0,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_context * mctx_cur) {
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
||||
|
||||
auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
||||
|
||||
return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_k * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
// expand k later to enable rope fusion which directly writes into k-v cache
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
const auto & k_idxs = inp->get_k_idxs();
|
||||
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
||||
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
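The K-only attention path above allocates no separate V tensor: V is read back as a view of the K cache whose rows are just the first v-head-dim elements of each K row. A tiny standalone sketch of that aliasing idea, with hypothetical names:

#include <cstddef>
#include <vector>

// Conceptual sketch (hypothetical names): a cache that stores only K rows and exposes
// V as a view into the same storage — each "V row" is simply the first n_embd_v
// elements of the corresponding K row, so no second buffer is allocated.
struct k_only_cache {
    size_t n_embd_k; // full K row width
    size_t n_embd_v; // V width, n_embd_v <= n_embd_k
    std::vector<float> k; // n_cells * n_embd_k

    float       * k_row(size_t cell)       { return k.data() + cell * n_embd_k; }
    // V is not stored separately: it aliases the start of the K row
    const float * v_row(size_t cell) const { return k.data() + cell * n_embd_k; }
};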
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
@@ -1834,8 +2118,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1848,8 +2134,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
@@ -1985,17 +2273,62 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
|
||||
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
||||
|
||||
// build iswa attention input
|
||||
const auto * attn_ctx = mctx_cur->get_attn();
|
||||
|
||||
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_base()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
void llm_graph_context::build_dense_out(
|
||||
ggml_tensor * dense_2,
|
||||
ggml_tensor * dense_3) const {
|
||||
if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
|
||||
if (!cparams.embeddings || !(dense_2 || dense_3)) {
|
||||
return;
|
||||
}
|
||||
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
|
||||
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
||||
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
||||
if (dense_2) {
|
||||
cur = ggml_mul_mat(ctx0, dense_2, cur);
|
||||
}
|
||||
if (dense_3) {
|
||||
cur = ggml_mul_mat(ctx0, dense_3, cur);
|
||||
}
|
||||
cb(cur, "result_embd_pooled", -1);
|
||||
res->t_embd_pooled = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@@ -2086,6 +2419,87 @@ void llm_graph_context::build_pooling(
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
void llm_graph_context::build_sampling() const {
|
||||
if (samplers.empty() || !res->t_logits) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_logit_row[seq_id] = logit_row_idx;
|
||||
logit_row_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
|
||||
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
|
||||
|
||||
// add a dummy row of logits
|
||||
// this trick makes the graph static, regardless of which samplers are activated
|
||||
// this is important in order to minimize graph reallocations
|
||||
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
|
||||
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
const auto it = seq_to_logit_row.find(seq_id);
|
||||
|
||||
// inactive samplers always work on the first row
|
||||
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
|
||||
struct llama_sampler_data data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled =*/ nullptr,
|
||||
/*.candidates =*/ nullptr,
|
||||
};
|
||||
|
||||
assert(sampler->iface->backend_apply);
|
||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
ggml_build_forward_expand(gf, data.sampled);
|
||||
}
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
ggml_build_forward_expand(gf, data.probs);
|
||||
}
|
||||
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
ggml_build_forward_expand(gf, data.logits);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
ggml_build_forward_expand(gf, data.candidates);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
||||
/*
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
|
||||
ggml_tensor * selected_token = it->second;
|
||||
if (selected_token != nullptr) {
|
||||
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
||||
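The build_sampling hunk above pads the logits with one dummy row so the graph stays static no matter which samplers are active in a given batch. A small standalone sketch of the same trick on a plain matrix, with hypothetical names:

#include <cstddef>
#include <map>
#include <vector>

// Conceptual sketch of the "dummy row" trick: keep one extra row in the logits matrix
// so it always has at least one row, then give every sampler a row index — its real
// output row if its sequence produced an output this batch, row 0 otherwise. The set
// of views handed to the samplers (and hence the graph shape) never changes with the
// set of active samplers.
struct logits_matrix {
    size_t n_vocab;
    std::vector<float> data; // (n_rows + 1) * n_vocab, last row is the dummy pad

    const float * row(size_t idx) const { return data.data() + idx * n_vocab; }
};

static std::map<int, const float *> assign_rows(
        const logits_matrix & logits,
        const std::map<int, int> & seq_to_row,        // seq_id -> output row, active seqs only
        const std::vector<int> & all_sampler_seqs) {
    std::map<int, const float *> views;
    for (int seq_id : all_sampler_seqs) {
        auto it = seq_to_row.find(seq_id);
        const size_t row_idx = (it != seq_to_row.end()) ? (size_t) it->second : 0;
        views[seq_id] = logits.row(row_idx); // inactive samplers read row 0
    }
    return views;
}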
165  llama/llama.cpp/src/llama-graph.h  (vendored)
@@ -10,6 +10,7 @@
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
@@ -23,6 +24,7 @@ class llama_kv_cache_context;
|
||||
class llama_kv_cache_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
class llama_memory_hybrid_iswa_context;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@@ -104,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
||||
|
||||
class llm_graph_input_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_embd() = default;
|
||||
llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
|
||||
virtual ~llm_graph_input_embd() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
@@ -113,6 +115,8 @@ public:
|
||||
|
||||
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||
|
||||
const int64_t n_embd = 0;
|
||||
};
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
@@ -313,6 +317,39 @@ public:
|
||||
const llama_kv_cache_context * mctx;
|
||||
};
|
||||
|
||||
// V-less input for the KV cache
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
|
||||
class llm_graph_input_attn_k : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_k(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_k() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_kv_cache_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_iswa(
|
||||
@@ -396,6 +433,46 @@ public:
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid_iswa(
|
||||
const llama_cparams & cparams,
|
||||
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_iswa_context * mctx) :
|
||||
inp_attn(std::move(inp_attn)),
|
||||
inp_rs(std::move(inp_rs)),
|
||||
cparams(cparams),
|
||||
mctx(mctx) { }
|
||||
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
|
||||
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_memory_hybrid_iswa_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||
samplers(std::move(samplers)) { }
|
||||
virtual ~llm_graph_input_sampling() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@@ -429,6 +506,23 @@ struct llm_graph_params {
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
static bool samplers_equal(
|
||||
const std::map<llama_seq_id, llama_sampler *> & lhs,
|
||||
const std::map<llama_seq_id, llama_sampler *> & rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto & [seq_id, sampler] : lhs) {
|
||||
auto it = rhs.find(seq_id);
|
||||
if (it == rhs.end() || it->second != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
llm_graph_cb cb;
|
||||
@@ -468,15 +562,36 @@ struct llm_graph_params {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (n_outputs != other.n_outputs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!samplers_equal(samplers, other.samplers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (samplers.size() > 0) {
|
||||
if (!ubatch.data || !other.ubatch.data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check that the outputs are the same for all samplers
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (ubatch.output[i] != other.ubatch.output[i] ||
|
||||
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
n_outputs == other.n_outputs;
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -486,7 +601,7 @@ public:
|
||||
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() const { return t_tokens; }
|
||||
ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
|
||||
ggml_tensor * get_logits() const { return t_logits; }
|
||||
ggml_tensor * get_embd() const { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
||||
@@ -499,6 +614,7 @@ public:
|
||||
void reset();
|
||||
|
||||
void set_inputs(const llama_ubatch * ubatch);
|
||||
void set_outputs();
|
||||
|
||||
// try to update the existing graph result using the new graph parameters in order to reuse it
|
||||
// this can only be done if we determine that the resulting graph using the new graph parameters
|
||||
@@ -512,11 +628,17 @@ public:
|
||||
void set_params(const llm_graph_params & params);
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_inp_tokens = nullptr;
|
||||
ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
@@ -592,6 +714,8 @@ struct llm_graph_context {
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
llm_graph_result * res;
|
||||
@@ -742,6 +866,21 @@ struct llm_graph_context {
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_k * build_attn_inp_k() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_k * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
@@ -822,6 +961,8 @@ struct llm_graph_context {
|
||||
|
||||
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
||||
|
||||
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
|
||||
|
||||
//
|
||||
// pooling
|
||||
//
|
||||
@@ -832,6 +973,12 @@ struct llm_graph_context {
|
||||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
|
||||
//
|
||||
// sampling (backend sampling)
|
||||
//
|
||||
|
||||
void build_sampling() const;
|
||||
|
||||
//
|
||||
// dense (out)
|
||||
//
|
||||
|
||||
55  llama/llama.cpp/src/llama-hparams.cpp  (vendored)
@@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const {
|
||||
return n_embd_inp;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_out() const {
|
||||
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
@@ -179,6 +183,21 @@ bool llama_hparams::is_swa(uint32_t il) const {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
bool llama_hparams::is_mla() const {
|
||||
assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
|
||||
(n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));
|
||||
|
||||
return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_k_mla() const {
|
||||
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_v_mla() const {
|
||||
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
if (n_layer_kv_from_start >= 0) {
|
||||
if (il < (uint32_t) n_layer_kv_from_start) {
|
||||
@@ -204,42 +223,6 @@ uint32_t llama_hparams::n_layer_kv() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_STANDARD:
|
||||
{
|
||||
if (p1 - p0 >= (int32_t) n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED:
|
||||
{
|
||||
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
||||
|
||||
if (p0 < pos_chunk_start) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC:
|
||||
{
|
||||
const int32_t half_n_swa = (int32_t) n_swa / 2;
|
||||
const int32_t pos_diff = p1 - p0;
|
||||
|
||||
// Mask if outside the symmetric window
|
||||
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::use_mrope() const {
|
||||
return rope_sections[0] > 0 && rope_sections[1] > 0;
|
||||
}
|
||||
|
||||
66  llama/llama.cpp/src/llama-hparams.h  (vendored)
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
@@ -52,8 +53,8 @@ struct llama_hparams {
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
|
||||
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
||||
uint32_t n_embd_head_k_mla = 0;
|
||||
uint32_t n_embd_head_v_mla = 0;
|
||||
uint32_t n_embd_head_k_mla_impl = 0;
|
||||
uint32_t n_embd_head_v_mla_impl = 0;
|
||||
|
||||
// for WavTokenizer
|
||||
struct llama_hparams_posnet posnet;
|
||||
@@ -107,9 +108,9 @@ struct llama_hparams {
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_base_train_swa;
|
||||
float rope_freq_base_train_swa = 10000.0f;
|
||||
float rope_freq_scale_train;
|
||||
float rope_freq_scale_train_swa;
|
||||
float rope_freq_scale_train_swa = 1.0f;
|
||||
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul = 0.0f;
|
||||
@@ -125,10 +126,11 @@ struct llama_hparams {
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// if swa_layers[il] == 1, then layer il is SWA
|
||||
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
// note: using uint32_t type for compatibility reason
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
@@ -163,6 +165,9 @@ struct llama_hparams {
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
// output embedding dimension (0 = use n_embd)
|
||||
uint32_t n_embd_out_impl = 0;
|
||||
|
||||
// llama4 smallthinker
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
@@ -235,6 +240,9 @@ struct llama_hparams {
|
||||
// dimension of main + auxiliary input embeddings
|
||||
uint32_t n_embd_inp() const;
|
||||
|
||||
// dimension of output embeddings
|
||||
uint32_t n_embd_out() const;
|
||||
|
||||
// dimension of key embeddings across all k-v heads
|
||||
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
||||
|
||||
@@ -266,15 +274,57 @@ struct llama_hparams {
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
|
||||
// note: currently only support if either all or none of the layers are MLA
|
||||
bool is_mla() const;
|
||||
|
||||
uint32_t n_embd_head_k_mla() const;
|
||||
uint32_t n_embd_head_v_mla() const;
|
||||
|
||||
bool has_kv(uint32_t il) const;
|
||||
|
||||
// number of layers for which has_kv() returns true
|
||||
uint32_t n_layer_kv() const;
|
||||
|
||||
// note that this function uses different SWA parameters from those in the hparams
|
||||
// note: inlined on purpose for performance reasons
|
||||
// TODO: think of a better place for this function
|
||||
// TODO: pack the SWA params in a struct?
|
||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
|
||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_STANDARD:
|
||||
{
|
||||
if (p1 - p0 >= (int32_t) n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED:
|
||||
{
|
||||
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
|
||||
|
||||
if (p0 < pos_chunk_start) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC:
|
||||
{
|
||||
const int32_t half_n_swa = (int32_t) n_swa / 2;
|
||||
const int32_t pos_diff = p1 - p0;
|
||||
|
||||
// Mask if outside the symmetric window
|
||||
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
|
||||
return true;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool use_mrope() const;
|
||||
};
|
||||
|
||||
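Since is_masked_swa is now inlined in the header, here is a standalone restatement of the three window rules it implements, with a couple of worked values; the enum and function names are local to this sketch.

#include <cassert>
#include <cstdint>
#include <cstdio>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED, SWA_SYMMETRIC };

// p0 = cached position, p1 = current token position, n_swa = window size.
// Returns true if the (p0, p1) pair must be masked out.
static bool masked_swa(uint32_t n_swa, swa_type type, int32_t p0, int32_t p1) {
    assert(p0 >= 0 && p1 >= 0);
    switch (type) {
        case SWA_NONE:
            return false;
        case SWA_STANDARD:
            // keep only the last n_swa positions before p1
            return p1 - p0 >= (int32_t) n_swa;
        case SWA_CHUNKED:
            // keep only positions inside the chunk that contains p1
            return p0 < (p1 / (int32_t) n_swa) * (int32_t) n_swa;
        case SWA_SYMMETRIC: {
            // keep positions within +/- n_swa/2 of p1
            const int32_t half = (int32_t) n_swa / 2;
            return (p1 - p0) < -half || (p1 - p0) > half;
        }
    }
    return false;
}

int main() {
    // window of 4: standard SWA keeps p0 in [p1-3, p1]; chunked keeps the chunk [8, 11] for p1 = 10
    std::printf("%d %d\n", masked_swa(4, SWA_STANDARD, 5, 10), masked_swa(4, SWA_CHUNKED, 7, 10));
    // prints: 1 1
}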
326  llama/llama.cpp/src/llama-kv-cache.cpp  (vendored)
@@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache(
|
||||
__func__, hparams.n_embd_v_gqa_max());
|
||||
}
|
||||
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
@@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache(
|
||||
throw std::runtime_error("failed to create ggml context for kv cache");
|
||||
}
|
||||
|
||||
ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
|
||||
ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
|
||||
const bool has_k = true;
|
||||
const bool has_v = !is_mla;
|
||||
|
||||
ggml_format_name(k, "cache_k_l%d", il);
|
||||
ggml_format_name(v, "cache_v_l%d", il);
|
||||
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
|
||||
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
|
||||
|
||||
has_k && ggml_format_name(k, "cache_k_l%d", il);
|
||||
has_v && ggml_format_name(v, "cache_v_l%d", il);
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
|
||||
v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
|
||||
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
|
||||
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
|
||||
}
|
||||
|
||||
map_layer_ids[il] = layers.size();
|
||||
@@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
|
||||
const auto & layer = layers[il];
|
||||
|
||||
ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
|
||||
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
|
||||
|
||||
if (layer.v_stream[ssrc]) {
|
||||
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -852,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
||||
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
||||
|
||||
// SWA mask
|
||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
can_use = true;
|
||||
}
|
||||
}
|
||||
@@ -1237,6 +1245,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
|
||||
}
|
||||
}
|
||||
|
||||
struct args_set_input_kq_mask {
|
||||
const llama_hparams & hparams;
|
||||
const llama_ubatch * ubatch;
|
||||
|
||||
const std::vector<llama_kv_cells> & v_cells;
|
||||
const std::vector<uint32_t> & seq_to_stream;
|
||||
|
||||
uint32_t n_swa;
|
||||
llama_swa_type swa_type;
|
||||
|
||||
int64_t n_kv;
|
||||
int64_t n_stream;
|
||||
int64_t n_tps;
|
||||
};
|
||||
|
||||
template<bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
//const auto & hparams = args.hparams;
|
||||
const auto & ubatch = args.ubatch;
|
||||
|
||||
const auto & v_cells = args.v_cells;
|
||||
const auto & seq_to_stream = args.seq_to_stream;
|
||||
|
||||
const uint32_t n_swa = args.n_swa;
|
||||
const llama_swa_type swa_type = args.swa_type;
|
||||
|
||||
const int64_t n_kv = args.n_kv;
|
||||
const int64_t n_stream = args.n_stream;
|
||||
const int64_t n_tps = args.n_tps;
|
||||
|
||||
// the min position in the batch for each sequence
|
||||
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
||||
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
// bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
|
||||
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
|
||||
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
|
||||
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
|
||||
|
||||
llama_pos p0 = -1;
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*i;
|
||||
|
||||
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
|
||||
// the only cells that could change are the ones that are with similar positions as the
|
||||
// ones in the batch (i.e. due to causal masking, SWA, etc.)
|
||||
// keep track of those cells and shortcut the loop to save time
|
||||
// note: this optimization is not compatible with Alibi position encoding
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
|
||||
bool prev = false;
|
||||
|
||||
auto & idxs = seq_idxs[seq_id];
|
||||
|
||||
if (!alibi) {
|
||||
if (seq_srct.find(seq_id) != seq_srct.end()) {
|
||||
const uint32_t srct = seq_srct[seq_id];
|
||||
|
||||
const uint64_t idst_prev = n_kv*srct;
|
||||
|
||||
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
|
||||
|
||||
prev = true;
|
||||
} else {
|
||||
idxs.clear();
|
||||
idxs.reserve(ubatch->n_tokens + n_swa + 32);
|
||||
|
||||
seq_srct[seq_id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t jj = 0; jj < n_kv; ++jj) {
|
||||
uint32_t j = jj;
|
||||
|
||||
// we have an existing mask for this sequence -> update just seq_idxs
|
||||
if (!alibi) {
|
||||
if (prev) {
|
||||
if (jj >= idxs.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
j = idxs[jj];
|
||||
}
|
||||
}
|
||||
|
||||
if (cells.is_empty(j)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
p0 = cells.pos_get(j);
|
||||
|
||||
if (!alibi) {
|
||||
if (!prev) {
|
||||
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
|
||||
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
|
||||
idxs.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (causal) {
|
||||
// mask future tokens
|
||||
if (p0 > p1) {
|
||||
goto skip;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (is_2d) {
|
||||
if (p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (swa) {
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
goto skip;
|
||||
}
|
||||
}
|
||||
|
||||
if (alibi) {
|
||||
data[idst + j] = -std::abs(p0 - p1);
|
||||
} else {
|
||||
data[idst + j] = 0.0f;
|
||||
}
|
||||
|
||||
continue;
|
||||
skip:
|
||||
data[idst + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool alibi = args.hparams.use_alibi;
|
||||
if (alibi) {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool is_2d = args.ubatch->is_pos_2d();
|
||||
if (is_2d) {
|
||||
set_input_kq_mask_impl<causal, swa, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
if (swa) {
|
||||
set_input_kq_mask_impl<causal, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
@@ -1251,74 +1450,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
||||
// n_tps == n_tokens_per_stream
|
||||
const int64_t n_tps = n_tokens/n_stream;
|
||||
|
||||
std::fill(data, data + ggml_nelements(dst), -INFINITY);
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
// TODO: optimize this section
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
for (uint32_t ii = 0; ii < n_tps; ++ii) {
|
||||
const uint32_t i = s*n_tps + ii;
|
||||
const args_set_input_kq_mask args = {
|
||||
/*.hparams =*/ hparams,
|
||||
/*.ubatch =*/ ubatch,
|
||||
/*.v_cells =*/ v_cells,
|
||||
/*.seq_to_stream =*/ seq_to_stream,
|
||||
/*.n_swa =*/ n_swa,
|
||||
/*.swa_type =*/ swa_type,
|
||||
/*.n_kv =*/ n_kv,
|
||||
/*.n_stream =*/ n_stream,
|
||||
/*.n_tps =*/ n_tps,
|
||||
};
|
||||
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
const auto & cells = v_cells[seq_to_stream[seq_id]];
|
||||
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
// for M-RoPE
|
||||
const bool is_2d = ubatch->is_pos_2d();
|
||||
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
|
||||
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
|
||||
|
||||
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
|
||||
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
if (cells.is_empty(j)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// mask the token if not the same sequence
|
||||
if (!cells.seq_has(j, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = cells.pos_get(j);
|
||||
|
||||
// mask future tokens
|
||||
if (causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// M-RoPE causal mask
|
||||
if (causal_attn && is_2d && p0 == p1) {
|
||||
const auto & p0_ext = cells.ext_get(j);
|
||||
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// apply SWA if any
|
||||
if (is_masked_swa(p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (causal_attn) {
|
||||
set_input_kq_mask_impl<true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<false>(args, data);
|
||||
}
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
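The rewritten mask filler above turns the causal/SWA/2-D/ALiBi flags into template parameters so the per-cell loop carries no runtime branching on them. A minimal sketch of that dispatch pattern follows; the per-element work is a placeholder, not the real mask logic.

#include <cstdio>

// The innermost routine is compiled once per flag combination, so the hot loop has no
// runtime checks on the flags; a small chain of dispatchers converts runtime bools
// into template arguments.
template <bool causal, bool alibi>
static void fill_mask_impl(float * data, int n) {
    for (int i = 0; i < n; ++i) {
        float v = 0.0f;
        if (causal && i % 2 == 0) {       // placeholder "mask" condition
            v = -1e30f;
        } else if (alibi) {
            v = -(float) i;               // placeholder alibi-style bias
        }
        data[i] = v;
    }
}

template <bool causal>
static void fill_mask(float * data, int n, bool alibi) {
    if (alibi) { fill_mask_impl<causal, true >(data, n); }
    else       { fill_mask_impl<causal, false>(data, n); }
}

static void fill_mask(float * data, int n, bool causal, bool alibi) {
    if (causal) { fill_mask<true >(data, n, alibi); }
    else        { fill_mask<false>(data, n, alibi); }
}

int main() {
    float m[4];
    fill_mask(m, 4, /*causal=*/true, /*alibi=*/false);
    std::printf("%g %g %g %g\n", m[0], m[1], m[2], m[3]); // -1e+30 0 -1e+30 0
}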
@@ -1370,7 +1524,7 @@ size_t llama_kv_cache::size_v_bytes() const {
|
||||
size_t size_v_bytes = 0;
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
size_v_bytes += ggml_nbytes(layer.v);
|
||||
size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
|
||||
}
|
||||
|
||||
return size_v_bytes;
|
||||
@@ -1448,6 +1602,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
@@ -1468,10 +1626,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx, layer.k,
|
||||
n_embd_head_k, n_head_kv, get_size()*n_stream,
|
||||
n_rot, n_head_kv, get_size()*n_stream,
|
||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
0);
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
||||
|
||||
@@ -1483,10 +1641,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
@@ -1652,6 +1806,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
auto * v = layer.v_stream[cr.strm];
|
||||
if (!v) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
@@ -1678,6 +1835,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
auto * v = layer.v_stream[cr.strm];
|
||||
if (!v) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
@@ -1881,6 +2041,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
auto * v = layer.v_stream[strm];
|
||||
if (!v) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -1922,6 +2085,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
auto * v = layer.v_stream[strm];
|
||||
if (!v) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
|
||||
4
llama/llama.cpp/src/llama-kv-cache.h
vendored
4
llama/llama.cpp/src/llama-kv-cache.h
vendored
@@ -257,8 +257,6 @@ private:
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
@@ -305,7 +303,7 @@ public:
|
||||
bool do_shift,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
slot_info_vec_t sinfos,
|
||||
|
||||
275
llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp
vendored
Normal file
275
llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp
vendored
Normal file
@@ -0,0 +1,275 @@
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa
|
||||
//
|
||||
|
||||
llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn,
|
||||
const layer_filter_cb & filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache_iswa(
|
||||
model,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
offload,
|
||||
swa_full,
|
||||
unified,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
n_ubatch,
|
||||
n_pad,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
// follow the recurrent pattern for creating the ubatch splits
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
// prepare the attention cache (iswa version returns both base and swa slot infos)
|
||||
auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while(false);
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa::get_can_shift() const {
|
||||
// Shifting is trivially supported for recurrent
|
||||
return mem_attn->get_can_shift();
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::clear(bool data) {
|
||||
mem_attn->clear(data);
|
||||
mem_recr->clear(data);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// Try removing from the recurrent cache first since it may fail. If it does
|
||||
// fail, the cache will not have been mutated.
|
||||
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
|
||||
return false;
|
||||
}
|
||||
return mem_attn->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
mem_attn->seq_keep(seq_id);
|
||||
mem_recr->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
mem_attn->seq_add(seq_id, p0, p1, shift);
|
||||
mem_recr->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
mem_attn->seq_div(seq_id, p0, p1, d);
|
||||
mem_recr->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the min of the total cache is the max of the two caches' min values
|
||||
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
// the max of the total cache is the min of the two caches' max values
|
||||
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
|
||||
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
|
||||
for (const auto & buft_size : mem_recr->memory_breakdown()) {
|
||||
mb[buft_size.first] += buft_size.second;
|
||||
}
|
||||
return mb;
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
mem_attn->state_write(io, seq_id, flags);
|
||||
mem_recr->state_write(io, seq_id, flags);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
mem_attn->state_read(io, seq_id, flags);
|
||||
mem_recr->state_read(io, seq_id, flags);
|
||||
}
|
||||
|
||||
llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
|
||||
return mem_attn.get();
|
||||
}
|
||||
|
||||
llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
|
||||
return mem_recr.get();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa_context
|
||||
//
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
|
||||
ctx_attn(mem->get_mem_attn()->init_full()),
|
||||
ctx_recr(mem->get_mem_recr()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
||||
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_attn->next();
|
||||
ctx_recr->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_attn->apply();
|
||||
res = res & ctx_recr->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
|
||||
}
|
||||
|
||||
const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
|
||||
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||
}
|
||||
140
llama/llama.cpp/src/llama-memory-hybrid-iswa.h
vendored
Normal file
140
llama/llama.cpp/src/llama-memory-hybrid-iswa.h
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa
|
||||
//
|
||||
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
|
||||
// support models where each layer may be either attention-based (with SWA support) or recurrent
|
||||
|
||||
class llama_memory_hybrid_iswa : public llama_memory_i {
|
||||
public:
|
||||
llama_memory_hybrid_iswa(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn = nullptr,
|
||||
const layer_filter_cb & filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid_iswa() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_iswa * get_mem_attn() const;
|
||||
llama_memory_recurrent * get_mem_recr() const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
|
||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||
};
|
||||
|
||||
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
|
||||
|
||||
// init full
|
||||
explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
|
||||
|
||||
// init update
|
||||
explicit llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// init success
|
||||
llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_iswa_context() = default;
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa_context
|
||||
//
|
||||
|
||||
const llama_kv_cache_iswa_context * get_attn() const;
|
||||
const llama_memory_recurrent_context * get_recr() const;
|
||||
|
||||
private:
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
const llama_memory_context_ptr ctx_attn;
|
||||
const llama_memory_context_ptr ctx_recr;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
222
llama/llama.cpp/src/llama-mmap.cpp
vendored
222
llama/llama.cpp/src/llama-mmap.cpp
vendored
@@ -13,9 +13,10 @@
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
@@ -74,7 +75,7 @@ struct llama_file::impl {
|
||||
return ret;
|
||||
}
|
||||
|
||||
impl(const char * fname, const char * mode) {
|
||||
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
@@ -109,7 +110,7 @@ struct llama_file::impl {
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
void read_raw(void * ptr, size_t len) {
|
||||
size_t bytes_read = 0;
|
||||
while (bytes_read < len) {
|
||||
size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
|
||||
@@ -126,7 +127,7 @@ struct llama_file::impl {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
uint32_t read_u32() {
|
||||
uint32_t val;
|
||||
read_raw(&val, sizeof(val));
|
||||
return val;
|
||||
@@ -153,16 +154,55 @@ struct llama_file::impl {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
bool has_direct_io() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
#else
|
||||
impl(const char * fname, const char * mode) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) {
|
||||
#ifdef __linux__
|
||||
// Try unbuffered I/O for read only
|
||||
if (use_direct_io && std::strcmp(mode, "rb") == 0) {
|
||||
if (init_fd()) {
|
||||
return;
|
||||
}
|
||||
LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O",
|
||||
fname, strerror(errno));
|
||||
}
|
||||
#endif
|
||||
init_fp(mode);
|
||||
}
|
||||
|
||||
#ifdef __linux__
|
||||
bool init_fd() {
|
||||
fd = open(fname.c_str(), O_RDONLY | O_DIRECT);
|
||||
|
||||
if (fd != -1) {
|
||||
struct stat file_stats{};
|
||||
fstat(fd, &file_stats);
|
||||
|
||||
size = file_stats.st_size;
|
||||
alignment = file_stats.st_blksize;
|
||||
|
||||
off_t ret = lseek(fd, 0, SEEK_SET);
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
void init_fp(const char * mode) {
|
||||
fp = ggml_fopen(fname.c_str(), mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno)));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
@@ -170,46 +210,122 @@ struct llama_file::impl {
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||
if (fd == -1) {
|
||||
long ret = std::ftell(fp);
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||
}
|
||||
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
return (size_t) ret;
|
||||
off_t pos = lseek(fd, 0, SEEK_CUR);
|
||||
if (pos == -1) {
|
||||
throw std::runtime_error(format("lseek error: %s", strerror(errno)));
|
||||
}
|
||||
return (size_t) pos;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
if (ret != 0) {
|
||||
off_t ret = 0;
|
||||
if (fd == -1) {
|
||||
ret = std::fseek(fp, (long) offset, whence);
|
||||
} else {
|
||||
ret = lseek(fd, offset, whence);
|
||||
}
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
void read_raw_unsafe(void * ptr, size_t len) {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
if (fd == -1) {
|
||||
const size_t curr_off = tell();
|
||||
const size_t to_read = std::min(len, size - curr_off);
|
||||
|
||||
std::size_t ret = std::fread(ptr, to_read, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (to_read > 0 && ret != 1) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
} else {
|
||||
size_t bytes_read = 0;
|
||||
while (bytes_read < len) {
|
||||
const size_t to_read = len - bytes_read;
|
||||
ssize_t ret = ::read(fd, reinterpret_cast<char *>(ptr) + bytes_read, to_read);
|
||||
|
||||
if (ret == -1) {
|
||||
if (errno == EINTR) {
|
||||
continue; // Interrupted by signal, retry
|
||||
}
|
||||
// Fallback to std::fread in case the DMA controller cannot access the buffer
|
||||
if (errno == EFAULT || errno == EINVAL) {
|
||||
LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno));
|
||||
auto curr_off = tell();
|
||||
close(fd);
|
||||
fd = -1;
|
||||
alignment = 1;
|
||||
init_fp("rb");
|
||||
seek(curr_off, SEEK_SET);
|
||||
read_raw_unsafe(ptr, len);
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret == 0) {
|
||||
// EOF: allow if this read was only pulling alignment padding past file end
|
||||
off_t pos = lseek(fd, 0, SEEK_CUR);
|
||||
if (pos != -1 && (size_t) pos == size) {
|
||||
std::memset(reinterpret_cast<char *>(ptr) + bytes_read, 0, len - bytes_read);
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
|
||||
bytes_read += (size_t) ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
void read_aligned_chunk(void * dest, size_t size) {
|
||||
size_t offset = tell();
|
||||
off_t aligned_offset = offset & ~(alignment - 1);
|
||||
off_t offset_from_alignment = offset - aligned_offset;
|
||||
size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
void * raw_buffer = nullptr;
|
||||
int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read);
|
||||
if (ret != 0) {
|
||||
throw std::runtime_error(format("posix_memalign failed with error %d", ret));
|
||||
}
|
||||
|
||||
struct aligned_buffer_deleter {
|
||||
void operator()(void * p) const { free(p); }
|
||||
};
|
||||
std::unique_ptr<void, aligned_buffer_deleter> buffer(raw_buffer);
|
||||
|
||||
seek(aligned_offset, SEEK_SET);
|
||||
read_raw_unsafe(buffer.get(), bytes_to_read);
|
||||
|
||||
uintptr_t actual_data = reinterpret_cast<uintptr_t>(buffer.get()) + offset_from_alignment;
|
||||
memcpy(dest, reinterpret_cast<void *>(actual_data), size);
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) {
|
||||
if (has_direct_io()) {
|
||||
read_aligned_chunk(ptr, len);
|
||||
} else {
|
||||
read_raw_unsafe(ptr, len);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() {
|
||||
uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
@@ -230,27 +346,48 @@ struct llama_file::impl {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
bool has_direct_io() const {
|
||||
return fd != -1 && alignment > 1;
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
if (fd != -1) {
|
||||
close(fd);
|
||||
} else {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
int fd = -1;
|
||||
std::string fname;
|
||||
#endif
|
||||
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
size_t read_alignment() const {
|
||||
return alignment;
|
||||
}
|
||||
|
||||
size_t alignment = 1;
|
||||
|
||||
FILE * fp{};
|
||||
size_t size{};
|
||||
};
|
||||
|
||||
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
||||
llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) :
|
||||
pimpl(std::make_unique<impl>(fname, mode, use_direct_io)) {}
|
||||
llama_file::~llama_file() = default;
|
||||
|
||||
size_t llama_file::tell() const { return pimpl->tell(); }
|
||||
size_t llama_file::size() const { return pimpl->size; }
|
||||
|
||||
size_t llama_file::read_alignment() const { return pimpl->read_alignment(); }
|
||||
bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); }
|
||||
|
||||
int llama_file::file_id() const {
|
||||
#ifdef _WIN32
|
||||
return _fileno(pimpl->fp);
|
||||
#else
|
||||
if (pimpl->fd != -1) {
|
||||
return pimpl->fd;
|
||||
}
|
||||
#if defined(fileno)
|
||||
return fileno(pimpl->fp);
|
||||
#else
|
||||
@@ -260,9 +397,14 @@ int llama_file::file_id() const {
|
||||
}
|
||||
|
||||
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
||||
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
||||
void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
|
||||
#ifdef _WIN32
|
||||
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
|
||||
#else
|
||||
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); }
|
||||
#endif
|
||||
|
||||
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
||||
uint32_t llama_file::read_u32() { return pimpl->read_u32(); }
|
||||
|
||||
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
||||
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
||||
@@ -476,9 +618,9 @@ struct llama_mlock::impl {
|
||||
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
|
||||
// visionOS/tvOS dont't support RLIMIT_MEMLOCK
|
||||
// Skip resource limit checks on visionOS/tvOS
|
||||
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__)
|
||||
// visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK
|
||||
// Skip resource limit checks on these platforms
|
||||
suggest = false;
|
||||
#else
|
||||
struct rlimit lock_limit;
|
||||
|
||||
11
llama/llama.cpp/src/llama-mmap.h
vendored
11
llama/llama.cpp/src/llama-mmap.h
vendored
@@ -3,6 +3,7 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
struct llama_file;
|
||||
struct llama_mmap;
|
||||
@@ -13,7 +14,7 @@ using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
||||
|
||||
struct llama_file {
|
||||
llama_file(const char * fname, const char * mode);
|
||||
llama_file(const char * fname, const char * mode, bool use_direct_io = false);
|
||||
~llama_file();
|
||||
|
||||
size_t tell() const;
|
||||
@@ -23,12 +24,16 @@ struct llama_file {
|
||||
|
||||
void seek(size_t offset, int whence) const;
|
||||
|
||||
void read_raw(void * ptr, size_t len) const;
|
||||
uint32_t read_u32() const;
|
||||
void read_raw(void * ptr, size_t len);
|
||||
void read_raw_unsafe(void * ptr, size_t len);
|
||||
void read_aligned_chunk(void * dest, size_t size);
|
||||
uint32_t read_u32();
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const;
|
||||
void write_u32(uint32_t val) const;
|
||||
|
||||
size_t read_alignment() const;
|
||||
bool has_direct_io() const;
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
112
llama/llama.cpp/src/llama-model-loader.cpp
vendored
112
llama/llama.cpp/src/llama-model-loader.cpp
vendored
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
@@ -344,6 +345,7 @@ namespace GGUFMeta {
|
||||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_BOOL:
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
@@ -365,7 +367,13 @@ namespace GGUFMeta {
|
||||
result[i] = value;
|
||||
}
|
||||
} else {
|
||||
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
||||
if (arr_info.gt == GGUF_TYPE_BOOL) {
|
||||
std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) {
|
||||
return static_cast<T>(x);
|
||||
});
|
||||
} else {
|
||||
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -462,6 +470,29 @@ namespace GGUFMeta {
|
||||
return get_key_or_arr(llm_kv(kid), result, n, required);
|
||||
}
|
||||
|
||||
bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) {
|
||||
const std::string key = llm_kv(kid);
|
||||
|
||||
const int id = gguf_find_key(meta.get(), key.c_str());
|
||||
|
||||
if (id < 0) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// throw and error if type is an array
|
||||
if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) {
|
||||
if (required) {
|
||||
throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return get_key(key, result, required);
|
||||
}
|
||||
|
||||
// TODO: this is not very clever - figure out something better
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
|
||||
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
|
||||
@@ -472,6 +503,7 @@ llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits,
|
||||
bool use_mmap,
|
||||
bool use_direct_io,
|
||||
bool check_tensors,
|
||||
bool no_alloc,
|
||||
const llama_model_kv_override * param_overrides_p,
|
||||
@@ -504,9 +536,23 @@ llama_model_loader::llama_model_loader(
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb"));
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
if (use_mmap && use_direct_io) {
|
||||
if (files.back()->has_direct_io()) {
|
||||
// Disable mmap, as DirectIO is available
|
||||
use_mmap = false;
|
||||
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
|
||||
} else {
|
||||
// Disable DirectIO and reopen file using std::fopen for mmap
|
||||
use_direct_io = false;
|
||||
files.pop_back();
|
||||
files.emplace_back(new llama_file(fname.c_str(), "rb", false));
|
||||
LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// Save tensors data offset of the main file.
|
||||
// For subsidiary files, `meta` tensor data offset must not be used,
|
||||
// so we build a unified tensors index for weights.
|
||||
@@ -572,7 +618,7 @@ llama_model_loader::llama_model_loader(
|
||||
}
|
||||
}
|
||||
|
||||
files.emplace_back(new llama_file(fname_split, "rb"));
|
||||
files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
|
||||
contexts.emplace_back(ctx);
|
||||
|
||||
// Save tensors data offset info of the shard.
|
||||
@@ -716,6 +762,7 @@ llama_model_loader::llama_model_loader(
|
||||
}
|
||||
|
||||
this->use_mmap = use_mmap;
|
||||
this->use_direct_io = use_direct_io;
|
||||
this->check_tensors = check_tensors;
|
||||
this->no_alloc = no_alloc;
|
||||
}
|
||||
@@ -935,7 +982,15 @@ bool llama_model_loader::load_all_data(
|
||||
// 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
|
||||
// NVMe raid configurations might require more / larger buffers.
|
||||
constexpr size_t n_buffers = 4;
|
||||
constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
|
||||
|
||||
size_t alignment = 1;
|
||||
for (const auto & file : files) {
|
||||
alignment = std::max(file->read_alignment(), alignment);
|
||||
}
|
||||
|
||||
// Buffer size: balance between memory usage and I/O efficiency
|
||||
// 64MB works well for NVMe drives
|
||||
const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024;
|
||||
|
||||
std::vector<ggml_backend_buffer_t> host_buffers;
|
||||
std::vector<ggml_backend_event_t> events;
|
||||
@@ -985,6 +1040,7 @@ bool llama_model_loader::load_all_data(
|
||||
// If the backend is supported, create pinned memory buffers and events for synchronisation.
|
||||
for (size_t idx = 0; idx < n_buffers; ++idx) {
|
||||
auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
|
||||
|
||||
if (!buf) {
|
||||
LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
|
||||
ggml_backend_dev_name(dev));
|
||||
@@ -1066,6 +1122,7 @@ bool llama_model_loader::load_all_data(
|
||||
}
|
||||
} else {
|
||||
const auto & file = files.at(weight->idx);
|
||||
|
||||
if (ggml_backend_buffer_is_host(cur->buffer)) {
|
||||
file->seek(weight->offs, SEEK_SET);
|
||||
file->read_raw(cur->data, n_size);
|
||||
@@ -1077,19 +1134,54 @@ bool llama_model_loader::load_all_data(
|
||||
} else {
|
||||
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
|
||||
if (upload_backend) {
|
||||
file->seek(weight->offs, SEEK_SET);
|
||||
size_t offset = weight->offs;
|
||||
alignment = file->read_alignment();
|
||||
size_t aligned_offset = offset & ~(alignment - 1);
|
||||
size_t offset_from_alignment = offset - aligned_offset;
|
||||
file->seek(aligned_offset, SEEK_SET);
|
||||
|
||||
// Calculate aligned read boundaries
|
||||
size_t read_start = aligned_offset;
|
||||
size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
size_t bytes_read = 0;
|
||||
size_t data_read = 0; // Actual tensor data copied (excluding padding)
|
||||
|
||||
while (bytes_read < n_size) {
|
||||
size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);
|
||||
while (bytes_read < read_end - read_start) {
|
||||
size_t read_size = std::min<size_t>(buffer_size, read_end - read_start - bytes_read);
|
||||
|
||||
// Align the destination pointer within the pinned buffer
|
||||
uintptr_t ptr_dest_aligned = (reinterpret_cast<uintptr_t>(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1);
|
||||
|
||||
// Wait for previous upload to complete before reusing buffer
|
||||
ggml_backend_event_synchronize(events[buffer_idx]);
|
||||
file->read_raw(host_ptrs[buffer_idx], read_iteration);
|
||||
ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
|
||||
|
||||
// Read aligned chunk from file
|
||||
file->read_raw_unsafe(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
|
||||
|
||||
// Calculate actual data portion (excluding alignment padding)
|
||||
uintptr_t ptr_data = ptr_dest_aligned;
|
||||
size_t data_to_copy = read_size;
|
||||
|
||||
// Skip alignment padding at start of first chunk
|
||||
if (bytes_read == 0) {
|
||||
ptr_data += offset_from_alignment;
|
||||
data_to_copy -= offset_from_alignment;
|
||||
}
|
||||
|
||||
// Trim alignment padding at end of last chunk
|
||||
if (aligned_offset + bytes_read + read_size > offset + n_size) {
|
||||
data_to_copy -= (read_end - (offset + n_size));
|
||||
}
|
||||
|
||||
// Async upload actual data to GPU
|
||||
ggml_backend_tensor_set_async(upload_backend, cur,
|
||||
reinterpret_cast<void *>(ptr_data), data_read, data_to_copy);
|
||||
ggml_backend_event_record(events[buffer_idx], upload_backend);
|
||||
|
||||
bytes_read += read_iteration;
|
||||
data_read += data_to_copy;
|
||||
bytes_read += read_size;
|
||||
|
||||
++buffer_idx;
|
||||
buffer_idx %= n_buffers;
|
||||
}
|
||||
|
||||
4
llama/llama.cpp/src/llama-model-loader.h
vendored
4
llama/llama.cpp/src/llama-model-loader.h
vendored
@@ -70,6 +70,7 @@ struct llama_model_loader {
|
||||
size_t n_bytes = 0;
|
||||
|
||||
bool use_mmap = false;
|
||||
bool use_direct_io = false;
|
||||
bool check_tensors;
|
||||
bool no_alloc;
|
||||
|
||||
@@ -97,6 +98,7 @@ struct llama_model_loader {
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
|
||||
bool use_mmap,
|
||||
bool use_direct_io,
|
||||
bool check_tensors,
|
||||
bool no_alloc,
|
||||
const llama_model_kv_override * param_overrides_p,
|
||||
@@ -131,6 +133,8 @@ struct llama_model_loader {
|
||||
template<typename T>
|
||||
bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
|
||||
|
||||
bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true);
|
||||
|
||||
std::string get_arch_name() const;
|
||||
|
||||
enum llm_arch get_arch() const;
|
||||
|
||||
3
llama/llama.cpp/src/llama-model-saver.cpp
vendored
3
llama/llama.cpp/src/llama-model-saver.cpp
vendored
@@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() {
|
||||
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
|
||||
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
if (hparams.n_embd_out_impl > 0) {
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl);
|
||||
}
|
||||
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
|
||||
|
||||
791
llama/llama.cpp/src/llama-model.cpp
vendored
791
llama/llama.cpp/src/llama-model.cpp
vendored
File diff suppressed because it is too large
Load Diff
16
llama/llama.cpp/src/llama-model.h
vendored
16
llama/llama.cpp/src/llama-model.h
vendored
@@ -11,6 +11,7 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
@@ -24,12 +25,14 @@ enum llm_type {
|
||||
LLM_TYPE_17M,
|
||||
LLM_TYPE_22M,
|
||||
LLM_TYPE_33M,
|
||||
LLM_TYPE_47M,
|
||||
LLM_TYPE_60M,
|
||||
LLM_TYPE_70M,
|
||||
LLM_TYPE_80M,
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_140M,
|
||||
LLM_TYPE_149M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
@@ -39,6 +42,7 @@ enum llm_type {
|
||||
LLM_TYPE_335M,
|
||||
LLM_TYPE_350M,
|
||||
LLM_TYPE_360M,
|
||||
LLM_TYPE_395M,
|
||||
LLM_TYPE_410M,
|
||||
LLM_TYPE_450M,
|
||||
LLM_TYPE_475M,
|
||||
@@ -117,10 +121,12 @@ enum llm_type {
|
||||
LLM_TYPE_31B_A3_5B,
|
||||
LLM_TYPE_80B_A3B, // Qwen3 Next
|
||||
LLM_TYPE_100B_A6B,
|
||||
LLM_TYPE_102B_A12B, // Solar-Open
|
||||
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||
LLM_TYPE_230B_A10B, // Minimax M2
|
||||
LLM_TYPE_235B_A22B,
|
||||
LLM_TYPE_300B_A47B, // Ernie MoE big
|
||||
LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
|
||||
LLM_TYPE_355B_A32B, // GLM-4.5
|
||||
LLM_TYPE_E2B,
|
||||
LLM_TYPE_E4B,
|
||||
@@ -465,8 +471,6 @@ struct llama_model {
|
||||
struct ggml_tensor * dense_2_out_layers = nullptr;
|
||||
struct ggml_tensor * dense_3_out_layers = nullptr;
|
||||
|
||||
llama_model_params params;
|
||||
|
||||
// gguf metadata
|
||||
std::unordered_map<std::string, std::string> gguf_kv;
|
||||
|
||||
@@ -476,6 +480,9 @@ struct llama_model {
|
||||
// for quantize-stats only
|
||||
std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
|
||||
|
||||
// for keeping track of associated LoRA adapters
|
||||
std::unordered_set<llama_adapter_lora *> loras;
|
||||
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
|
||||
@@ -497,6 +504,9 @@ struct llama_model {
|
||||
size_t n_tensors() const;
|
||||
size_t n_devices() const;
|
||||
|
||||
uint32_t n_gpu_layers() const;
|
||||
llama_split_mode split_mode() const;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
|
||||
|
||||
// total number of parameters in the model
|
||||
@@ -525,6 +535,8 @@ struct llama_model {
|
||||
ggml_cgraph * build_graph(const llm_graph_params & params) const;
|
||||
|
||||
private:
|
||||
llama_model_params params;
|
||||
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
111
llama/llama.cpp/src/llama-quant.cpp
vendored
111
llama/llama.cpp/src/llama-quant.cpp
vendored
@@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||
++qs.i_ffn_up;
|
||||
}
|
||||
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// IK: let's remove this, else Q2_K is almost the same as Q3_K_S
|
||||
//else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// This can be used to reduce the size of the Q5_K_S model.
|
||||
// The associated PPL increase is fully in line with the size reduction
|
||||
//else {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
|
||||
//}
|
||||
bool convert_incompatible_tensor = false;
|
||||
{
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t ny = tensor->ne[1];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (nx % qk_k != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
} else {
|
||||
++qs.n_k_quantized;
|
||||
}
|
||||
}
|
||||
|
||||
if (convert_incompatible_tensor) {
|
||||
switch (new_type) {
|
||||
case GGML_TYPE_TQ1_0:
|
||||
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
||||
new_type = GGML_TYPE_F16;
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
++qs.n_fallback;
|
||||
}
|
||||
|
||||
return new_type;
|
||||
}
|
||||
|
||||
@@ -596,7 +545,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
}
|
||||
|
||||
std::vector<std::string> splits = {};
|
||||
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
|
||||
llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
|
||||
ml.init_mappings(false); // no prefetching
|
||||
|
||||
llama_model model(llama_model_default_params());
|
||||
@@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if (!params->pure && ggml_is_quantized(default_type)) {
|
||||
int fallback = qs.n_fallback;
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
// unless the user specifies a type, and the tensor geometry will not require fallback quantisation
|
||||
if (params->tensor_types && qs.n_fallback - fallback == 0) {
|
||||
// if the user provided tensor types - use those
|
||||
bool manual = false;
|
||||
if (params->tensor_types) {
|
||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||
const std::string tensor_name(tensor->name);
|
||||
for (const auto & [tname, qtype] : tensor_types) {
|
||||
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
|
||||
if (qtype != new_type) {
|
||||
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
|
||||
LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if not manual - use the standard logic for choosing the quantization type based on the selected mixture
|
||||
if (!manual) {
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
}
|
||||
|
||||
// incompatible tensor shapes are handled here - fallback to a compatible type
|
||||
{
|
||||
bool convert_incompatible_tensor = false;
|
||||
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t ny = tensor->ne[1];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (nx % qk_k != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
} else {
|
||||
++qs.n_k_quantized;
|
||||
}
|
||||
|
||||
if (convert_incompatible_tensor) {
|
||||
switch (new_type) {
|
||||
case GGML_TYPE_TQ1_0:
|
||||
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
||||
new_type = GGML_TYPE_F16;
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
++qs.n_fallback;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
new_type = params->token_embedding_type;
|
||||
|
||||
1716
llama/llama.cpp/src/llama-sampling.cpp
vendored
1716
llama/llama.cpp/src/llama-sampling.cpp
vendored
File diff suppressed because it is too large
Load Diff
26
llama/llama.cpp/src/llama-sampling.h
vendored
26
llama/llama.cpp/src/llama-sampling.h
vendored
@@ -14,7 +14,19 @@ struct llama_grammar;
|
||||
struct llama_sampler_chain {
|
||||
llama_sampler_chain_params params;
|
||||
|
||||
std::vector<struct llama_sampler *> samplers;
|
||||
// has .backend_init() been called?
|
||||
bool is_init = false;
|
||||
|
||||
struct info {
|
||||
bool is_backend;
|
||||
|
||||
llama_sampler * ptr;
|
||||
};
|
||||
|
||||
std::vector<info> samplers;
|
||||
|
||||
// pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
// timing
|
||||
|
||||
@@ -24,9 +36,9 @@ struct llama_sampler_chain {
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_testing(
|
||||
int32_t context_size,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const std::vector<std::vector<llama_token>>& seq_breakers);
|
||||
int32_t context_size,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const std::vector<std::vector<llama_token>> & seq_breakers);
|
||||
|
||||
190
llama/llama.cpp/src/llama-vocab.cpp
vendored
190
llama/llama.cpp/src/llama-vocab.cpp
vendored
@@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_YOUTU:
|
||||
regex_exprs = {
|
||||
"[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥-ゟ゠-ヿ]+",
|
||||
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
||||
regex_exprs = {
|
||||
"[\r\n]",
|
||||
@@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
|
||||
case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
@@ -454,6 +461,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1849,6 +1863,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "deepseek-v3") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "youtu") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU;
|
||||
clean_spaces = false;
|
||||
ignore_merges = true;
|
||||
} else if (
|
||||
tokenizer_pre == "falcon") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
|
||||
@@ -1867,7 +1886,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "a.x-4.0" ||
|
||||
tokenizer_pre == "mellum") {
|
||||
tokenizer_pre == "mellum" ||
|
||||
tokenizer_pre == "modern-bert" ) {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
@@ -1941,6 +1961,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "exaone4") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "exaone-moe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE;
|
||||
} else if (
|
||||
tokenizer_pre == "chameleon") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
|
||||
@@ -2003,6 +2026,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "minimax-m2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "solar-open") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
@@ -2049,6 +2076,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
|
||||
}
|
||||
|
||||
const uint32_t n_scores = score_idx != -1 ? gguf_get_arr_n(ctx, score_idx) : 0;
|
||||
const int * toktypes = nullptr;
|
||||
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
|
||||
if (toktype_idx != -1) {
|
||||
@@ -2070,7 +2098,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|
||||
auto & token_data = id_to_token[i];
|
||||
token_data.text = std::move(word);
|
||||
token_data.score = scores ? scores[i] : 0.0f;
|
||||
token_data.score = (scores && i < n_scores) ? scores[i] : 0.0f;
|
||||
token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
|
||||
|
||||
if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
|
||||
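The @@ -2049 hunk reads the number of entries in the score array (n_scores) and only indexes it when i < n_scores, so a GGUF file whose score array is shorter than the token list no longer reads out of bounds. The same defensive pattern in isolation:

// Sketch of the defensive lookup above: only index into the optional scores
// array when it exists and is long enough; otherwise fall back to 0.0f.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> scores  = {0.1f, 0.2f};   // e.g. fewer entries than tokens
    const uint32_t     n_vocab = 4;

    const float *  data     = scores.empty() ? nullptr : scores.data();
    const uint32_t n_scores = (uint32_t) scores.size();

    for (uint32_t i = 0; i < n_vocab; i++) {
        const float score = (data && i < n_scores) ? data[i] : 0.0f;
        printf("token %u -> score %.2f\n", i, score);
    }
    return 0;
}
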
@@ -2176,6 +2204,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
// for now, we apply this workaround to find the tokens based on their text
|
||||
|
||||
for (const auto & t : token_to_id) {
|
||||
auto & attr = id_to_token[t.second].attr;
|
||||
|
||||
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
|
||||
if (special_eot_id == LLAMA_TOKEN_NULL) {
|
||||
if (false
|
||||
@@ -2191,10 +2221,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<end_of_utterance>" // smoldocling
|
||||
) {
|
||||
special_eot_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
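Several hunks in this file replace repeated id_to_token[t.second].attr lookups with a single reference attr bound at the top of the loop; the control-type check and the override then go through the same lvalue. Minimal sketch of the pattern:

// Minimal sketch: bind a reference to a struct field once so that the check
// and the override operate on the same stored value.
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

enum token_attr : uint32_t { ATTR_NORMAL = 1, ATTR_CONTROL = 2 };

struct token_data { std::string text; uint32_t attr; };

int main() {
    std::vector<token_data>    id_to_token = {{"<|eot_id|>", ATTR_NORMAL}};
    std::map<std::string, int> token_to_id = {{"<|eot_id|>", 0}};

    for (const auto & t : token_to_id) {
        uint32_t & attr = id_to_token[t.second].attr;   // one lookup, reused below
        if ((attr & ATTR_CONTROL) == 0) {
            printf("overriding attr of '%s'\n", t.first.c_str());
            attr |= ATTR_CONTROL;                        // updates the stored value
        }
    }
    printf("attr = %u\n", id_to_token[0].attr);          // prints 3
    return 0;
}
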
@@ -2205,10 +2235,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|eom_id|>"
|
||||
) {
|
||||
special_eom_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2225,10 +2255,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|code_prefix|>" // GLM-4.5
|
||||
) {
|
||||
special_fim_pre_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2245,10 +2275,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|code_suffix|>" // GLM-4.5
|
||||
) {
|
||||
special_fim_suf_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2265,10 +2295,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|code_middle|>" // GLM-4.5
|
||||
) {
|
||||
special_fim_mid_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2282,10 +2312,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<PAD>"
|
||||
) {
|
||||
special_fim_pad_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2300,10 +2330,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<reponame>" // Granite
|
||||
) {
|
||||
special_fim_rep_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2314,15 +2344,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|file_sep|>" // Qwen
|
||||
) {
|
||||
special_fim_sep_id = t.second;
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// auto-detect unused tokens: e.g. control tokens with the word "unused"
|
||||
// ideally, these tokens should be marked as unused during conversion
|
||||
{
|
||||
uint32_t n_unused = 0;
|
||||
|
||||
for (const auto & t : token_to_id) {
|
||||
auto & attr = id_to_token[t.second].attr;
|
||||
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) {
|
||||
if (strstr(t.first.c_str(), "unused") != NULL) {
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED);
|
||||
}
|
||||
}
|
||||
|
||||
if (attr & LLAMA_TOKEN_ATTR_UNUSED) {
|
||||
n_unused++;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused);
|
||||
}
|
||||
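The new block marks control tokens whose text contains "unused" with the UNUSED attribute and counts them. The same bitmask logic in standalone form:

// Sketch of the "unused" auto-detection above: control tokens whose text
// contains "unused" get the UNUSED bit, and the total is counted.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

enum : uint32_t { ATTR_CONTROL = 1u << 0, ATTR_UNUSED = 1u << 1 };

int main() {
    std::vector<std::pair<std::string, uint32_t>> tokens = {
        {"<unused_17>", ATTR_CONTROL},
        {"<|im_end|>",  ATTR_CONTROL},
        {"hello",       0},
    };

    uint32_t n_unused = 0;
    for (auto & [text, attr] : tokens) {
        if ((attr & ATTR_CONTROL) == 0) {
            continue;                       // only control tokens are considered
        }
        if ((attr & ATTR_UNUSED) == 0 && strstr(text.c_str(), "unused") != NULL) {
            attr |= ATTR_UNUSED;
        }
        if (attr & ATTR_UNUSED) {
            n_unused++;
        }
    }
    printf("%u unused tokens\n", n_unused);  // prints 1
    return 0;
}
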
|
||||
// maintain a list of tokens that cause end-of-generation
|
||||
// this is currently determined based on the token text, which is obviously not ideal
|
||||
// ref: https://github.com/ggerganov/llama.cpp/issues/9606
|
||||
@@ -2341,12 +2397,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
}
|
||||
|
||||
for (const auto & t : token_to_id) {
|
||||
auto & attr = id_to_token[t.second].attr;
|
||||
|
||||
if (false
|
||||
|| t.first == "<|eot_id|>"
|
||||
|| t.first == "<|im_end|>"
|
||||
|| t.first == "<|end|>"
|
||||
|| t.first == "<|return|>" // o200k_harmony
|
||||
|| t.first == "<|call|>" // o200k_harmony
|
||||
|| t.first == "<|flush|>" // solar-open
|
||||
|| t.first == "<|calls|>" // solar-open
|
||||
|| t.first == "<end_of_turn>"
|
||||
|| t.first == "<|endoftext|>"
|
||||
|| t.first == "<|eom_id|>"
|
||||
@@ -2356,24 +2416,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<end_of_utterance>" // smoldocling
|
||||
) {
|
||||
special_eog_ids.insert(t.second);
|
||||
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||
LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||
attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
|
||||
}
|
||||
} else {
|
||||
// token is control, but not marked as EOG -> print a debug log
|
||||
if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
|
||||
LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) {
|
||||
// token is control, but not marked as EOG -> print a debug log
|
||||
if (special_eog_ids.count(t.second) == 0) {
|
||||
LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
|
||||
__func__, t.second, t.first.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// @ngxson : quick hack for gpt-oss, always render these tokens
|
||||
for (const auto & t : token_to_id) {
|
||||
auto & attr = id_to_token[t.second].attr;
|
||||
|
||||
if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
|
||||
LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n",
|
||||
__func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr);
|
||||
|
||||
attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2393,34 +2460,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||
}
|
||||
|
||||
// TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
|
||||
// we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
|
||||
// TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG
|
||||
// we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open),
|
||||
// we remove the "<|end|>" token from the EOG list
|
||||
{
|
||||
bool has_return = false;
|
||||
bool has_call = false;
|
||||
bool has_end = false;
|
||||
bool has_flush = false;
|
||||
|
||||
llama_token end_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
|
||||
for (auto tid : special_eog_ids) {
|
||||
LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
|
||||
auto & text = id_to_token[tid].text;
|
||||
|
||||
if (id_to_token[tid].text == "<|return|>") {
|
||||
LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, text.c_str());
|
||||
|
||||
if (text == "<|return|>") {
|
||||
has_return = true;
|
||||
} else if (id_to_token[tid].text == "<|call|>") {
|
||||
} else if (text == "<|call|>" || text == "<|calls|>") {
|
||||
has_call = true;
|
||||
} else if (id_to_token[tid].text == "<|end|>") {
|
||||
} else if (text == "<|flush|>") {
|
||||
has_flush = true;
|
||||
} else if (text == "<|end|>") {
|
||||
has_end = true;
|
||||
end_id = tid;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_return && has_call && has_end) {
|
||||
if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) {
|
||||
special_eog_ids.erase(end_id);
|
||||
id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
|
||||
|
||||
auto & attr = id_to_token[end_id].attr;
|
||||
attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
|
||||
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
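The reworked workaround records which terminator tokens are present in the EOG set and removes "<|end|>" only when the o200k_harmony pair ("<|return|>"/"<|call|>") or the solar-open pair ("<|calls|>"/"<|flush|>") is there. Condensed sketch of the decision, with token ids replaced by their text for brevity:

// Condensed sketch of the decision above: "<|end|>" is dropped from the EOG
// set only when the harmony-style or solar-open-style terminators are present.
#include <cstdio>
#include <set>
#include <string>

int main() {
    std::set<std::string> eog = {"<|return|>", "<|call|>", "<|end|>", "<|endoftext|>"};

    const bool has_return = eog.count("<|return|>") > 0;
    const bool has_call   = eog.count("<|call|>")   > 0 || eog.count("<|calls|>") > 0;
    const bool has_flush  = eog.count("<|flush|>")  > 0;
    const bool has_end    = eog.count("<|end|>")    > 0;

    if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) {
        eog.erase("<|end|>");
        printf("removed <|end|> from EOG list\n");
    }
    printf("EOG tokens left: %zu\n", eog.size());   // prints 3
    return 0;
}
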
@@ -2518,6 +2593,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
|
||||
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
|
||||
}
|
||||
} else if (_contains_any(model_name, {"modern-bert"})) {
|
||||
if (token_to_id.count("[MASK]") == 0 ) {
|
||||
LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__);
|
||||
}
|
||||
else {
|
||||
_set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3211,34 +3293,34 @@ int32_t llama_vocab::impl::detokenize(
|
||||
}
|
||||
|
||||
void llama_vocab::impl::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
|
||||
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens());
|
||||
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
|
||||
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
|
||||
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens());
|
||||
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
|
||||
|
||||
// special tokens
|
||||
if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
|
||||
if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
|
||||
if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
|
||||
if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
|
||||
if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
|
||||
if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
|
||||
if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
|
||||
if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
|
||||
if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
|
||||
if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
|
||||
if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
|
||||
if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
|
||||
if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
|
||||
if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
|
||||
if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
|
||||
if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
|
||||
|
||||
if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
|
||||
if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
|
||||
|
||||
if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
|
||||
if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
|
||||
if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
|
||||
if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
|
||||
if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
|
||||
if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
|
||||
if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
|
||||
if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
|
||||
if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
|
||||
if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
|
||||
if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
|
||||
if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
|
||||
|
||||
for (const auto & id : special_eog_ids) {
|
||||
LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
|
||||
LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
|
||||
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
|
||||
}
|
||||
|
||||
llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
|
||||
|
||||
3
llama/llama.cpp/src/llama-vocab.h
vendored
@@ -51,6 +51,9 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
};

struct LLM_KV;

390
llama/llama.cpp/src/llama.cpp
vendored
@@ -71,8 +71,9 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
|
||||
}, &ud);
|
||||
|
||||
llama_model_params mparams_copy = *mparams;
|
||||
mparams_copy.no_alloc = true;
|
||||
mparams_copy.use_mmap = false;
|
||||
mparams_copy.no_alloc = true;
|
||||
mparams_copy.use_mmap = false;
|
||||
mparams_copy.use_mlock = false;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
|
||||
if (model == nullptr) {
|
||||
@@ -110,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < ret.size(); i++) {
|
||||
size_t free, total;
|
||||
size_t free;
|
||||
size_t total;
|
||||
ggml_backend_dev_memory(model->devices[i], &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
if (free == 0 && total == 0) {
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
ggml_backend_dev_memory(cpu_dev, &free, &total);
|
||||
}
|
||||
ret[i].free = free;
|
||||
ret[i].total = total;
|
||||
}
|
||||
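The @@ -110 hunk falls back to the CPU device's memory report when a device returns 0 free and 0 total bytes (see the referenced issue #18577). The fallback in isolation, with made-up numbers:

// Standalone sketch of the fallback above: if a device reports no memory at
// all, substitute the host (CPU) numbers so downstream sizing still works.
#include <cstdint>
#include <cstdio>

struct mem_info { uint64_t free; uint64_t total; };

static mem_info query_with_fallback(const mem_info & dev, const mem_info & host) {
    if (dev.free == 0 && dev.total == 0) {
        return host;   // the device had nothing to report, use host memory instead
    }
    return dev;
}

int main() {
    const mem_info host   = {32ull << 30, 64ull << 30};   // system RAM
    const mem_info gpu_ok = { 8ull << 30, 24ull << 30};   // discrete GPU
    const mem_info gpu_na = { 0, 0 };                     // e.g. a backend that cannot report memory

    const mem_info a = query_with_fallback(gpu_ok, host);
    const mem_info b = query_with_fallback(gpu_na, host);
    printf("a: %llu MiB free, b: %llu MiB free\n",
           (unsigned long long) (a.free >> 20), (unsigned long long) (b.free >> 20));
    return 0;
}
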
@@ -139,12 +152,15 @@ enum layer_fraction_t {
|
||||
};
|
||||
// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
|
||||
|
||||
class llama_params_fit_exception : public std::runtime_error {
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
static void llama_params_fit_impl(
|
||||
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
|
||||
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
|
||||
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||
size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||
constexpr int64_t MiB = 1024*1024;
|
||||
const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
|
||||
typedef std::vector<llama_device_memory_data> dmds_t;
|
||||
const llama_model_params default_mparams = llama_model_default_params();
|
||||
|
||||
@@ -163,6 +179,12 @@ static void llama_params_fit_impl(
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
|
||||
margins.reserve(nd);
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
margins.push_back(margins_s[id]);
|
||||
}
|
||||
|
||||
std::vector<std::string> dev_names;
|
||||
{
|
||||
dev_names.reserve(nd);
|
||||
@@ -180,11 +202,12 @@ static void llama_params_fit_impl(
|
||||
}
|
||||
}
|
||||
|
||||
int64_t sum_total = 0;
|
||||
int64_t sum_projected_free = 0;
|
||||
int64_t min_projected_free = INT64_MAX;
|
||||
int64_t sum_projected_used = 0;
|
||||
int64_t sum_projected_ctx = 0;
|
||||
int64_t sum_free = 0;
|
||||
int64_t sum_projected_free = 0;
|
||||
int64_t sum_projected_used = 0;
|
||||
int64_t sum_projected_model = 0;
|
||||
std::vector<int64_t> projected_free_per_device;
|
||||
projected_free_per_device.reserve(nd);
|
||||
|
||||
if (nd > 1) {
|
||||
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
|
||||
@@ -194,63 +217,106 @@ static void llama_params_fit_impl(
|
||||
|
||||
const int64_t projected_used = dmd.mb.total();
|
||||
const int64_t projected_free = dmd.free - projected_used;
|
||||
projected_free_per_device.push_back(projected_free);
|
||||
|
||||
sum_total += dmd.total;
|
||||
sum_projected_used += projected_used;
|
||||
sum_projected_free += projected_free;
|
||||
min_projected_free = std::min(min_projected_free, projected_free);
|
||||
sum_projected_ctx += dmd.mb.context;
|
||||
sum_free += dmd.free;
|
||||
sum_projected_used += projected_used;
|
||||
sum_projected_free += projected_free;
|
||||
sum_projected_model += dmd.mb.model;
|
||||
|
||||
if (nd > 1) {
|
||||
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
|
||||
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
|
||||
projected_free >= 0 ? "surplus" : "deficit");
|
||||
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n",
|
||||
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
|
||||
}
|
||||
}
|
||||
assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
|
||||
assert(sum_projected_used >= sum_projected_ctx);
|
||||
assert(sum_free >= 0 && sum_projected_used >= 0);
|
||||
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
|
||||
__func__, sum_projected_used/MiB, sum_total/MiB);
|
||||
if (min_projected_free >= margin) {
|
||||
if (nd == 1) {
|
||||
__func__, sum_projected_used/MiB, sum_free/MiB);
|
||||
if (nd == 1) {
|
||||
if (projected_free_per_device[0] >= margins[0]) {
|
||||
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
|
||||
__func__, min_projected_free/MiB, margin/MiB);
|
||||
__func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
bool changes_needed = false;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
if (projected_free_per_device[id] < margins[id]) {
|
||||
changes_needed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!changes_needed) {
|
||||
LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
|
||||
return;
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
|
||||
__func__, min_projected_free/MiB, margin/MiB);
|
||||
return;
|
||||
}
|
||||
|
||||
// step 2: try reducing memory use by reducing the context size
|
||||
|
||||
{
|
||||
int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
|
||||
int64_t global_surplus = sum_projected_free;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
global_surplus -= margins[id];
|
||||
}
|
||||
if (global_surplus < 0) {
|
||||
LLAMA_LOG_INFO(nd == 1 ?
|
||||
"%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
|
||||
"%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
|
||||
__func__, margin/MiB, -global_surplus/MiB);
|
||||
if (nd == 1) {
|
||||
LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n",
|
||||
__func__, margins[0]/MiB, -global_surplus/MiB);
|
||||
} else {
|
||||
LLAMA_LOG_INFO(
|
||||
"%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n",
|
||||
__func__, -global_surplus/MiB);
|
||||
}
|
||||
if (cparams->n_ctx == 0) {
|
||||
if (hp_nct > n_ctx_min) {
|
||||
const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
|
||||
const uint32_t ctx_reduction = std::min(
|
||||
uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
|
||||
cparams->n_ctx = hp_nct - ctx_reduction;
|
||||
const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
|
||||
global_surplus += memory_reduction;
|
||||
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
|
||||
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
|
||||
if (global_surplus >= 0) {
|
||||
int64_t sum_used_target = sum_free;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
sum_used_target -= margins[id];
|
||||
}
|
||||
if (nd > 1) {
|
||||
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
|
||||
// - for dense models only whole layers can be assigned to devices
|
||||
// - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
|
||||
// - on average we expect a waste of 0.5 layers/tensors per device
|
||||
// - use slightly more than the expected average for nd devices to be safe
|
||||
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
|
||||
sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
|
||||
}
|
||||
|
||||
int64_t sum_projected_used_min_ctx = 0;
|
||||
cparams->n_ctx = n_ctx_min;
|
||||
const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||
for (const auto & dmd : dmds_min_ctx) {
|
||||
sum_projected_used_min_ctx += dmd.mb.total();
|
||||
}
|
||||
if (sum_used_target > sum_projected_used_min_ctx) {
|
||||
// linear interpolation between minimum and maximum context size:
|
||||
cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
|
||||
/ (sum_projected_used - sum_projected_used_min_ctx);
|
||||
cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
|
||||
|
||||
const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
|
||||
const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
|
||||
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
|
||||
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
|
||||
if (nd == 1) {
|
||||
LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
|
||||
return;
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
|
||||
} else {
|
||||
const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
|
||||
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
|
||||
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
|
||||
}
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
|
||||
__func__, hp_nct, n_ctx_min);
|
||||
if (n_ctx_min == UINT32_MAX) {
|
||||
LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
|
||||
__func__, hp_nct, n_ctx_min);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
|
||||
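The reworked step 2 measures projected memory use at the minimum context, then linearly interpolates between n_ctx_min and the model's full context to pick the largest context that fits the memory targets, rounding down to a multiple of 256. A worked sketch with synthetic numbers:

// Worked sketch of the interpolation above (synthetic numbers): pick the
// largest context between n_ctx_min and the full context whose projected
// memory use stays under the target, then round down to a multiple of 256.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx_min   = 4096;
    const uint32_t hp_nct      = 131072;          // model's full context
    const int64_t  used_min    = 12000ll << 20;   // projected use at n_ctx_min
    const int64_t  used_full   = 30000ll << 20;   // projected use at hp_nct
    const int64_t  used_target = 20000ll << 20;   // free memory minus margins

    uint32_t n_ctx = n_ctx_min;
    if (used_target > used_min) {
        n_ctx += (hp_nct - n_ctx_min) * (used_target - used_min) / (used_full - used_min);
        n_ctx  = std::max(n_ctx - n_ctx % 256, n_ctx_min);   // round down for the CUDA backend
    }
    const int64_t bytes_per_ctx = (used_full - used_min) / (hp_nct - n_ctx_min);
    printf("n_ctx = %u, ~%lld MiB saved\n",
           n_ctx, (long long) ((hp_nct - n_ctx) * bytes_per_ctx >> 20));
    return 0;
}
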
@@ -259,32 +325,28 @@ static void llama_params_fit_impl(
|
||||
}
|
||||
|
||||
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
|
||||
throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
|
||||
throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
|
||||
}
|
||||
if (nd > 1) {
|
||||
if (!tensor_split) {
|
||||
throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
|
||||
throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort");
|
||||
}
|
||||
if (mparams->tensor_split) {
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
if (mparams->tensor_split[id] != 0.0f) {
|
||||
throw std::runtime_error("model_params::tensor_split already set by user, abort");
|
||||
throw llama_params_fit_exception("model_params::tensor_split already set by user, abort");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
|
||||
}
|
||||
if (hp_ngl < 2*nd) {
|
||||
throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least "
|
||||
+ std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort");
|
||||
throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
|
||||
}
|
||||
}
|
||||
if (!tensor_buft_overrides) {
|
||||
throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
|
||||
throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort");
|
||||
}
|
||||
if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
|
||||
throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
|
||||
throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort");
|
||||
}
|
||||
|
||||
// step 3: iteratively fill the back to front with "dense" layers
|
||||
@@ -337,6 +399,11 @@ static void llama_params_fit_impl(
|
||||
|
||||
// for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
|
||||
layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
|
||||
|
||||
uint32_t n_full() const {
|
||||
assert(n_layer >= n_part);
|
||||
return n_layer - n_part;
|
||||
}
|
||||
};
|
||||
|
||||
const size_t ntbo = llama_max_tensor_buft_overrides();
|
||||
@@ -345,8 +412,7 @@ static void llama_params_fit_impl(
|
||||
auto set_ngl_tensor_split_tbo = [&](
|
||||
const std::vector<ngl_t> & ngl_per_device,
|
||||
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
||||
llama_model_params & mparams,
|
||||
const bool add_nonrepeating) {
|
||||
llama_model_params & mparams) {
|
||||
mparams.n_gpu_layers = 0;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
mparams.n_gpu_layers += ngl_per_device[id].n_layer;
|
||||
@@ -354,29 +420,25 @@ static void llama_params_fit_impl(
|
||||
tensor_split[id] = ngl_per_device[id].n_layer;
|
||||
}
|
||||
}
|
||||
assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl);
|
||||
uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides
|
||||
assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1);
|
||||
uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides
|
||||
|
||||
if (add_nonrepeating) {
|
||||
mparams.n_gpu_layers += 1;
|
||||
tensor_split[nd - 1] += 1;
|
||||
}
|
||||
mparams.tensor_split = tensor_split;
|
||||
|
||||
size_t itbo = 0;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
|
||||
il0 += ngl_per_device[id].n_full();
|
||||
for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
|
||||
if (itbo + 1 >= ntbo) {
|
||||
tensor_buft_overrides[itbo].pattern = nullptr;
|
||||
tensor_buft_overrides[itbo].buft = nullptr;
|
||||
itbo++;
|
||||
mparams.tensor_buft_overrides = tensor_buft_overrides;
|
||||
throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
|
||||
+ std::to_string(ntbo) + " is insufficient for model\n");
|
||||
throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
|
||||
+ std::to_string(ntbo) + " is insufficient for model");
|
||||
}
|
||||
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
|
||||
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
|
||||
tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type();
|
||||
itbo++;
|
||||
}
|
||||
il0 += ngl_per_device[id].n_part;
|
||||
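set_ngl_tensor_split_tbo above turns the per-device layer plan into the public parameters: n_gpu_layers is the sum of assigned layers, tensor_split carries the per-device layer counts, and each partial layer gets one tensor-buft override. Reduced sketch of the parameter derivation (the override patterns are omitted here):

// Reduced sketch of set_ngl_tensor_split_tbo above: derive n_gpu_layers and a
// tensor_split from a per-device layer plan (buffer-type overrides omitted).
#include <cstdio>
#include <vector>

struct ngl_t { unsigned n_layer = 0; unsigned n_part = 0; };

int main() {
    const std::vector<ngl_t> ngl_per_device = {{10, 0}, {14, 2}};   // hypothetical 2-GPU plan

    int n_gpu_layers = 0;
    std::vector<float> tensor_split(ngl_per_device.size(), 0.0f);
    for (size_t id = 0; id < ngl_per_device.size(); id++) {
        n_gpu_layers += ngl_per_device[id].n_layer;
        tensor_split[id] = (float) ngl_per_device[id].n_layer;      // proportional split
    }
    printf("n_gpu_layers = %d, tensor_split = [%.0f, %.0f]\n",
           n_gpu_layers, tensor_split[0], tensor_split[1]);
    return 0;
}
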
@@ -391,10 +453,9 @@ static void llama_params_fit_impl(
|
||||
auto get_memory_for_layers = [&](
|
||||
const char * func_name,
|
||||
const std::vector<ngl_t> & ngl_per_device,
|
||||
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
||||
const bool add_nonrepeating) -> std::vector<int64_t> {
|
||||
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts) -> std::vector<int64_t> {
|
||||
llama_model_params mparams_copy = *mparams;
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating);
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy);
|
||||
|
||||
const dmds_t dmd_nl = llama_get_device_memory_data(
|
||||
path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||
@@ -427,9 +488,9 @@ static void llama_params_fit_impl(
|
||||
const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
|
||||
path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||
|
||||
for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
|
||||
global_surplus_cpu_moe += dmd.free;
|
||||
global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
global_surplus_cpu_moe += dmds_cpu_moe[id].free;
|
||||
global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id];
|
||||
}
|
||||
|
||||
if (global_surplus_cpu_moe > 0) {
|
||||
@@ -448,27 +509,18 @@ static void llama_params_fit_impl(
|
||||
std::vector<int64_t> targets; // maximum acceptable memory use per device
|
||||
targets.reserve(nd);
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
targets.push_back(dmds_full[id].free - margin);
|
||||
targets.push_back(dmds_full[id].free - margins[id]);
|
||||
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
|
||||
}
|
||||
|
||||
// whether for the optimal memory use we expect to load at least some MoE tensors:
|
||||
const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0;
|
||||
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the first partial layer of a device overflows to:
|
||||
overflow_bufts.reserve(nd);
|
||||
for (size_t id = 0; id < nd - 1; ++id) {
|
||||
overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||
}
|
||||
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||
|
||||
std::vector<ngl_t> ngl_per_device(nd);
|
||||
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe);
|
||||
if (hp_nex > 0) {
|
||||
for (size_t id = 0; id < nd; id++) {
|
||||
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
|
||||
|
||||
// optimize the number of layers per device using the method of false position:
|
||||
// - ngl_per_device has 0 layers for each device, lower bound
|
||||
@@ -476,22 +528,30 @@ static void llama_params_fit_impl(
|
||||
// - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target
|
||||
// - check memory use of our guess, replace either the low or high bound
|
||||
// - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits
|
||||
// - the last device has the output layer, which cannot be a partial layer
|
||||
if (hp_nex == 0) {
|
||||
LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
|
||||
}
|
||||
uint32_t n_unassigned = hp_ngl;
|
||||
for (int id = nd - 1; id >= 0; id--) {
|
||||
uint32_t n_unassigned = hp_ngl + 1;
|
||||
for (size_t jd = id + 1; jd < nd; ++jd) {
|
||||
assert(n_unassigned >= ngl_per_device[jd].n_layer);
|
||||
n_unassigned -= ngl_per_device[jd].n_layer;
|
||||
}
|
||||
|
||||
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||
ngl_per_device_high[id].n_layer = n_unassigned;
|
||||
if (hp_nex > 0) {
|
||||
ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer;
|
||||
ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1;
|
||||
}
|
||||
if (ngl_per_device_high[id].n_layer > 0) {
|
||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
|
||||
if (mem_high[id] > targets[id]) {
|
||||
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
||||
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
|
||||
while (delta > 1) {
|
||||
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||
step_size = std::max(step_size, uint32_t(1));
|
||||
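The comment block above describes the search: keep a lower layer count that fits and an upper one that does not, interpolate the next guess by memory use (method of false position), and stop once the bounds are one layer apart. A self-contained sketch against a synthetic per-layer memory model (the real code measures memory by dry-loading the model with no_alloc):

// Self-contained sketch of the false-position search described above, using a
// synthetic "memory used for n layers" model instead of a real dry load.
#include <cstdint>
#include <cstdio>

static int64_t mem_for_layers(uint32_t n_layers) {
    return 900ll*1024*1024 + (int64_t) n_layers * 350ll*1024*1024;   // fixed cost + per-layer cost
}

int main() {
    const int64_t target = 12ll*1024*1024*1024;   // memory budget for this device
    uint32_t lo = 0;                              // known to fit
    uint32_t hi = 48;                             // all unassigned layers

    int64_t mem_lo = mem_for_layers(lo);
    int64_t mem_hi = mem_for_layers(hi);

    if (mem_hi <= target) {
        lo = hi;                                  // everything fits, no search needed
    } else {
        uint32_t delta = hi - lo;
        while (delta > 1) {
            uint32_t step = (uint32_t) ((int64_t) delta * (target - mem_lo) / (mem_hi - mem_lo));
            if (step < 1)      step = 1;
            if (step >= delta) step = delta - 1;

            const uint32_t guess    = lo + step;
            const int64_t  mem_test = mem_for_layers(guess);
            if (mem_test <= target) {
                lo     = guess;                   // new lower bound that still fits
                mem_lo = mem_test;
            } else {
                hi     = guess;                   // new upper bound that does not fit
                mem_hi = mem_test;
            }
            delta = hi - lo;
        }
    }
    printf("assign %u layers (%lld MiB of %lld MiB)\n",
           lo, (long long) (mem_for_layers(lo) >> 20), (long long) (target >> 20));
    return 0;
}
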
@@ -500,25 +560,26 @@ static void llama_params_fit_impl(
|
||||
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||
ngl_per_device_test[id].n_layer += step_size;
|
||||
if (hp_nex) {
|
||||
ngl_per_device_test[id].n_part += step_size;
|
||||
ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ?
|
||||
step_size - 1 : step_size; // the first layer is the output layer which must always be full
|
||||
}
|
||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
|
||||
if (mem_test[id] <= targets[id]) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
mem = mem_test;
|
||||
n_unassigned -= ngl_per_device[id].n_layer;
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
mem = mem_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||
} else {
|
||||
ngl_per_device_high = ngl_per_device_test;
|
||||
mem_high = mem_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer);
|
||||
}
|
||||
delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||
}
|
||||
} else {
|
||||
ngl_per_device = ngl_per_device_high;
|
||||
n_unassigned -= ngl_per_device[id].n_layer;
|
||||
assert(ngl_per_device_high[id].n_layer == n_unassigned);
|
||||
ngl_per_device = ngl_per_device_high;
|
||||
mem = mem_high;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
|
||||
}
|
||||
}
|
||||
@@ -529,7 +590,7 @@ static void llama_params_fit_impl(
|
||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB);
|
||||
}
|
||||
if (hp_nex == 0 || global_surplus_cpu_moe <= 0) {
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -549,24 +610,20 @@ static void llama_params_fit_impl(
|
||||
assert(id_dense_start < nd);
|
||||
|
||||
LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
|
||||
for (size_t id = 0; id <= id_dense_start; id++) {
|
||||
for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) {
|
||||
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||
for (size_t jd = id_dense_start; jd < nd; jd++) {
|
||||
const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer;
|
||||
const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
|
||||
ngl_per_device_high[id].n_layer += n_layer_move;
|
||||
ngl_per_device_high[jd].n_layer -= n_layer_move;
|
||||
ngl_per_device_high[jd].n_part = 0;
|
||||
}
|
||||
size_t id_dense_start_high = nd - 1;
|
||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
|
||||
|
||||
if (mem_high[id] > targets[id]) {
|
||||
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
|
||||
assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
|
||||
assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
>= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
|
||||
uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
|
||||
while (delta > 1) {
|
||||
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||
step_size = std::max(step_size, uint32_t(1));
|
||||
@@ -582,11 +639,11 @@ static void llama_params_fit_impl(
|
||||
ngl_per_device_test[id].n_layer += n_convert_jd;
|
||||
n_converted_test += n_convert_jd;
|
||||
|
||||
if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
|
||||
if (ngl_per_device_test[id_dense_start_test].n_part > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||
|
||||
if (mem_test[id] <= targets[id]) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
@@ -601,32 +658,38 @@ static void llama_params_fit_impl(
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
|
||||
__func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
|
||||
}
|
||||
delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
|
||||
- (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
|
||||
assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
|
||||
delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
|
||||
}
|
||||
} else {
|
||||
ngl_per_device = ngl_per_device_high;
|
||||
mem = mem_high;
|
||||
id_dense_start = id_dense_start_high;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n",
|
||||
__func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
|
||||
}
|
||||
|
||||
// try to fit at least part of one more layer
|
||||
if (ngl_per_device[id_dense_start].n_layer > 0) {
|
||||
if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) {
|
||||
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||
size_t id_dense_start_test = id_dense_start;
|
||||
ngl_per_device_test[id_dense_start_test].n_layer--;
|
||||
ngl_per_device_test[id_dense_start_test].n_part--;
|
||||
ngl_per_device_test[id].n_layer++;
|
||||
ngl_per_device_test[id].n_part++;
|
||||
if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
|
||||
if (ngl_per_device_test[id_dense_start_test].n_part == 0) {
|
||||
id_dense_start_test++;
|
||||
}
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts_test = overflow_bufts;
|
||||
if (id < nd - 1) {
|
||||
overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]);
|
||||
}
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||
if (mem_test[id] < targets[id]) {
|
||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
|
||||
@@ -634,9 +697,10 @@ static void llama_params_fit_impl(
|
||||
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||
if (mem_test[id] < targets[id]) {
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
|
||||
@@ -645,9 +709,10 @@ static void llama_params_fit_impl(
|
||||
} else {
|
||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
||||
if (mem_test[id] < targets[id]) {
|
||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
|
||||
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||
ngl_per_device = ngl_per_device_test;
|
||||
overflow_bufts = overflow_bufts_test;
|
||||
mem = mem_test;
|
||||
id_dense_start = id_dense_start_test;
|
||||
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
|
||||
@@ -662,30 +727,41 @@ static void llama_params_fit_impl(
|
||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||
}
|
||||
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
||||
// print info for devices that were not changed during the conversion from dense only to full layers:
|
||||
for (size_t id = id_dense_start + 1; id < nd; id++) {
|
||||
const int64_t projected_margin = dmds_full[id].free - mem[id];
|
||||
LLAMA_LOG_INFO(
|
||||
"%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
|
||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||
}
|
||||
|
||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
|
||||
}
|
||||
|
||||
bool llama_params_fit(
|
||||
enum llama_params_fit_status llama_params_fit(
|
||||
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
|
||||
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
|
||||
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||
size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) {
|
||||
const int64_t t0_us = llama_time_us();
|
||||
bool ok = true;
|
||||
llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
|
||||
try {
|
||||
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
|
||||
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level);
|
||||
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
|
||||
} catch (const std::runtime_error & e) {
|
||||
} catch (const llama_params_fit_exception & e) {
|
||||
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
|
||||
ok = false;
|
||||
status = LLAMA_PARAMS_FIT_STATUS_FAILURE;
|
||||
} catch (const std::runtime_error & e) {
|
||||
LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what());
|
||||
status = LLAMA_PARAMS_FIT_STATUS_ERROR;
|
||||
}
|
||||
const int64_t t1_us = llama_time_us();
|
||||
LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
|
||||
return ok;
|
||||
return status;
|
||||
}
|
||||
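llama_params_fit now returns a three-way status instead of a bool, separating "could not fit the requested margins" from an actual error while fitting. A hedged usage sketch; the enum and signature are taken from the hunk above, while the device count (16) and the 1 GiB margins are arbitrary placeholders:

// Hedged usage sketch for the three-way status above (names and signature from
// the diff; buffer sizes and margins below are placeholders, not defaults).
#include <cstdio>
#include <vector>
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) return 1;
    const char * path_model = argv[1];

    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();

    std::vector<float>  tensor_split(16, 0.0f);                 // must cover all devices
    std::vector<size_t> margins(16, 1024ull*1024*1024);         // leave ~1 GiB free per device
    std::vector<llama_model_tensor_buft_override> tbo(llama_max_tensor_buft_overrides());

    switch (llama_params_fit(path_model, &mparams, &cparams,
                             tensor_split.data(), tbo.data(),
                             margins.data(), 4096, GGML_LOG_LEVEL_INFO)) {
        case LLAMA_PARAMS_FIT_STATUS_SUCCESS: break;                       // use fitted params
        case LLAMA_PARAMS_FIT_STATUS_FAILURE: /* keep default params */ break;
        case LLAMA_PARAMS_FIT_STATUS_ERROR:   fprintf(stderr, "fit error\n"); return 1;
    }
    return 0;
}
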
|
||||
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
|
||||
struct llama_sampler_chain_params result = {
|
||||
/*.no_perf =*/ true,
|
||||
/*.no_perf =*/ true,
|
||||
};
|
||||
|
||||
return result;
|
||||
@@ -758,7 +834,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
|
||||
model.t_start_us = tm.t_start_us;
|
||||
|
||||
try {
|
||||
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
|
||||
llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
|
||||
|
||||
ml.print_info();
|
||||
|
||||
@@ -1021,25 +1097,55 @@ int32_t llama_chat_apply_template(
|
||||
// model split
|
||||
//
|
||||
|
||||
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
||||
int32_t llama_split_path(
|
||||
char * split_path,
|
||||
size_t maxlen,
|
||||
const char * path_prefix,
|
||||
int32_t split_no,
|
||||
int32_t split_count) {
|
||||
|
||||
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
|
||||
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
|
||||
return strlen(split_path);
|
||||
|
||||
const int written = snprintf(
|
||||
split_path,
|
||||
maxlen,
|
||||
SPLIT_PATH_FORMAT,
|
||||
path_prefix,
|
||||
split_no + 1,
|
||||
split_count
|
||||
);
|
||||
|
||||
if (written < 0 || (size_t) written >= maxlen) {
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
|
||||
return (int32_t) written;
|
||||
}
|
||||
|
||||
int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
||||
std::string str_split_path(split_path);
|
||||
char postfix[32];
|
||||
snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
|
||||
std::string str_postfix(postfix);
|
||||
int32_t llama_split_prefix(
|
||||
char * split_prefix,
|
||||
size_t maxlen,
|
||||
const char * split_path,
|
||||
int32_t split_no,
|
||||
int32_t split_count) {
|
||||
|
||||
// check if split_prefix ends with postfix
|
||||
int size_prefix = str_split_path.size() - str_postfix.size();
|
||||
if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
|
||||
snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
|
||||
return size_prefix;
|
||||
const std::string str_split_path(split_path);
|
||||
|
||||
char postfix[32];
|
||||
snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
|
||||
|
||||
const std::string str_postfix(postfix);
|
||||
if (str_split_path.size() <= str_postfix.size()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t size_prefix = str_split_path.size() - str_postfix.size();
|
||||
|
||||
if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
|
||||
const size_t copy_len = std::min(size_prefix + 1, maxlen);
|
||||
snprintf(split_prefix, copy_len, "%s", split_path);
|
||||
|
||||
return (int32_t) size_prefix;
|
||||
}
|
||||
|
||||
return 0;
|
||||
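The rewritten helpers check snprintf truncation and verify the "-%05d-of-%05d.gguf" postfix before extracting the prefix. Usage sketch (the paths are made up; split_no is zero-based):

// Usage sketch for the two helpers rewritten above: build the path of split 2
// of 5 for a prefix, then recover the prefix from that path.
#include <cstdio>
#include "llama.h"

int main() {
    char path[512];
    if (llama_split_path(path, sizeof(path), "/models/llama-70b", 1, 5) == 0) {
        return 1;   // would have been truncated
    }
    printf("%s\n", path);     // /models/llama-70b-00002-of-00005.gguf

    char prefix[512];
    if (llama_split_prefix(prefix, sizeof(prefix), path, 1, 5) > 0) {
        printf("%s\n", prefix);   // /models/llama-70b
    }
    return 0;
}
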
|
||||
14
llama/llama.cpp/src/models/afmoe.cpp
vendored
14
llama/llama.cpp/src/models/afmoe.cpp
vendored
@@ -22,8 +22,15 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));

for (int il = 0; il < n_layer; ++il) {
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

ggml_tensor * inpSA = inpL;

// This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
(il + 1) % hparams.n_no_rope_layer_step != 0;

// dual attention normalization (pre)
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
@@ -56,19 +63,16 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);

// RoPE only for sliding_attention layers
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
((il + 1) % hparams.n_no_rope_layer_step) != 0;
if (use_rope) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_rope", il);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur_rope", il);
}
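The use_rope predicate in both hunks disables RoPE on every n_no_rope_layer_step-th layer, counting layers 1-based; the per-layer freq_base_l/freq_scale_l then only matter on layers that do get RoPE. A tiny standalone sketch of that stride arithmetic, with an assumed step of 4:

    // Sketch only: which layers get RoPE for an assumed n_no_rope_layer_step of 4.
    #include <cstdio>

    int main() {
        const int n_layer = 8;
        const int n_no_rope_layer_step = 4; // assumption for illustration
        for (int il = 0; il < n_layer; ++il) {
            const bool use_rope = n_no_rope_layer_step > 0 &&
                                  (il + 1) % n_no_rope_layer_step != 0;
            printf("layer %d: %s\n", il, use_rope ? "RoPE" : "no RoPE"); // layers 3 and 7 skip RoPE
        }
        return 0;
    }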
6
llama/llama.cpp/src/models/bert.cpp
vendored
@@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
LLM_FFN_GELU, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);
} else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff();
auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU;
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL,
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
type_op, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
cur = build_ffn(cur,
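up_contains_gate infers a fused gate purely from shape: when there is no separate ffn_gate tensor and ffn_up's output dimension is not n_ff, the up projection must carry gate and value halves concatenated, so the FFN op switches to LLM_FFN_GEGLU. A toy sketch of such a fused split on plain vectors, under the assumption that the gate half comes first; this is an illustration, not the ggml implementation:

    // Toy GEGLU on a fused [gate | value] buffer of length 2*n_ff.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static float gelu(float x) { // tanh approximation
        return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
    }

    static std::vector<float> geglu(const std::vector<float> & up, size_t n_ff) {
        std::vector<float> out(n_ff);
        for (size_t i = 0; i < n_ff; ++i) {
            out[i] = gelu(up[i]) * up[n_ff + i]; // gelu(gate) * value
        }
        return out;
    }

    int main() {
        const std::vector<float> up = {1.0f, -2.0f, 0.5f, 3.0f}; // n_ff = 2
        for (float v : geglu(up, 2)) printf("%f\n", v);
        return 0;
    }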
8
llama/llama.cpp/src/models/cogvlm.cpp
vendored
@@ -3,12 +3,14 @@
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
float kq_scale = 1.0f / sqrtf(float(n_embd_head));
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor *inpL, *cur;
ggml_tensor * inpL;
ggml_tensor * cur;

inpL = build_inp_embd(model.tok_embd);

ggml_tensor * inp_pos = build_inp_pos();
@@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
}

ggml_tensor * inpSA = inpL;
cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

// build self attention
{
3
llama/llama.cpp/src/models/cohere2-iswa.cpp
vendored
@@ -21,6 +21,9 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const
for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.is_swa(il);
// UNUSED:
// const float freq_base_l = model.get_rope_freq_base (cparams, il);
// const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
30
llama/llama.cpp/src/models/deepseek2.cpp
vendored
@@ -2,14 +2,11 @@
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);

const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
const bool is_mla = hparams.is_mla();

// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();

const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
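hparams.is_mla() folds the previous inline condition (both MLA head sizes non-zero) into a helper, and the new n_embd_head_k_mla()/n_embd_head_v_mla() accessors appear to absorb the non-MLA fallback that the removed ternaries handled. A minimal sketch of that kind of accessor, on a stripped-down struct assumed purely for illustration:

    // Sketch of MLA-aware hyperparameter accessors (hypothetical struct, not llama.cpp's).
    #include <cstdint>
    #include <cstdio>

    struct hparams_sketch {
        uint32_t n_embd_head_k_mla = 0;
        uint32_t n_embd_head_v_mla = 0;
        uint32_t n_embd_head_k     = 128;
        uint32_t n_embd_head_v     = 128;

        bool is_mla() const { return n_embd_head_k_mla != 0 && n_embd_head_v_mla != 0; }

        uint32_t head_k() const { return is_mla() ? n_embd_head_k_mla : n_embd_head_k; }
        uint32_t head_v() const { return is_mla() ? n_embd_head_v_mla : n_embd_head_v; }
    };

    int main() {
        hparams_sketch hp;
        hp.n_embd_head_k_mla = 192;
        hp.n_embd_head_v_mla = 128;
        printf("is_mla=%d head_k=%u head_v=%u\n", hp.is_mla(), hp.head_k(), hp.head_v());
        return 0;
    }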
@@ -43,7 +40,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();
auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr;

ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -57,6 +55,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
// self_attention
{
ggml_tensor * q = NULL;

const bool is_lite = model.layers[il].wq;

if (!is_lite) {
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
cb(q, "q", il);
@@ -124,14 +125,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
cb(Qcur, "Qcur", il);

kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);

// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
cb(Kcur, "Kcur", il);

// {kv_lora_rank, 1, n_tokens}
@@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
}

// note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
cur = build_attn(inp_attn,
cur = build_attn(inp_attn_k,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
} else {
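The comment in this hunk is the key to the inp_attn_k split above: with the absorption trick, every query head attends against the same compressed K/V, which is GQA with a single group, i.e. MQA. A tiny sketch of the query-head-to-kv-head mapping that terminology refers to, with assumed head counts:

    // Sketch: GQA grouping; n_head_kv == 1 is the MQA case described in the comment.
    #include <cstdio>

    int main() {
        const int n_head    = 8;
        const int n_head_kv = 1; // 1 -> MQA, n_head -> MHA, in between -> GQA
        const int group     = n_head / n_head_kv;
        for (int q = 0; q < n_head; ++q) {
            printf("query head %d -> kv head %d\n", q, q / group);
        }
        return 0;
    }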
@@ -169,11 +170,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);

// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(Kcur, "Kcur", il);

if (inp_attn_scale) {
@@ -183,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
}

// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
cur = build_attn(inp_attn,
cur = build_attn(inp_attn_kv,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
}
@@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
hparams.expert_weights_scale, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
146
llama/llama.cpp/src/models/exaone-moe.cpp
vendored
Normal file
@@ -0,0 +1,146 @@
#include "models.h"

llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn_iswa = build_attn_inp_kv_iswa();

ggml_tensor * inp_out_ids = build_inp_out_ids();

const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;

// use RoPE for SWA layers
const bool is_local_layer = hparams.is_swa(il);

// norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);

if (is_local_layer) {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn_iswa,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
cb(cur, "attn_out", il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// norm
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
// dense branch
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL, NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);

// FFN shared expert
{
ggml_tensor * ffn_shexp =
build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);

cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
}

cur = ggml_add(ctx0, cur, ffn_inp);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}
cur = inpL;

// final norm
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
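In the MoE branch above, build_moe_ffn routes each token through n_expert_used of n_expert experts, and the shared expert's FFN output is then always added on top. The following toy, CPU-only sketch shows that routing-plus-shared-expert combination for a single token with made-up numbers; the real graph additionally normalizes and scales the selected weights per hparams, which is omitted here.

    // Toy top-k MoE with a shared expert (illustration only, not the ggml graph).
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Stand-in for a full expert FFN: scale the input by a per-expert factor.
    static std::vector<float> expert_ffn(const std::vector<float> & x, float scale) {
        std::vector<float> y(x.size());
        for (size_t i = 0; i < x.size(); ++i) y[i] = scale * x[i];
        return y;
    }

    int main() {
        const std::vector<float> x = {1.0f, 2.0f};
        const std::vector<float> gate_logits = {0.1f, 2.0f, -1.0f, 1.5f}; // router output, one per expert
        const int n_expert_used = 2;

        // softmax over the router logits
        std::vector<float> p(gate_logits.size());
        const float mx = *std::max_element(gate_logits.begin(), gate_logits.end());
        float sum = 0.0f;
        for (size_t e = 0; e < p.size(); ++e) { p[e] = std::exp(gate_logits[e] - mx); sum += p[e]; }
        for (float & v : p) v /= sum;

        // select the top-k experts by routing probability
        std::vector<int> idx(p.size());
        for (size_t e = 0; e < idx.size(); ++e) idx[e] = (int) e;
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                          [&](int a, int b) { return p[a] > p[b]; });

        // weighted sum of the selected experts' outputs
        std::vector<float> moe_out(x.size(), 0.0f);
        for (int k = 0; k < n_expert_used; ++k) {
            const int e = idx[k];
            const std::vector<float> y = expert_ffn(x, 1.0f + (float) e);
            for (size_t i = 0; i < x.size(); ++i) moe_out[i] += p[e] * y[i];
        }

        // the shared expert is always added, as in the MoE branch above
        const std::vector<float> shexp = expert_ffn(x, 0.5f);
        for (size_t i = 0; i < x.size(); ++i) moe_out[i] += shexp[i];

        printf("%f %f\n", moe_out[0], moe_out[1]);
        return 0;
    }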
@@ -1,7 +1,5 @@
#include "models.h"

llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
@@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model,
inpL = build_inp_embd(model.tok_embd);

// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
cb(inpL, "inp_scaled", -1);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
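The gemma-embedding change folds the if (ubatch.token) branch into a single ggml_scale whose factor is 1.0f for raw (e.g. image) embedding inputs, i.e. a no-op scale, so the graph no longer branches on the input kind. A trivial standalone sketch of the equivalence, with made-up values:

    // Sketch: folding a conditional scale into an unconditional one with factor 1.0f.
    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_embd   = 4;
        const bool  is_token = false; // raw embedding input in this example
        const float x        = 0.25f; // one embedding value

        float a = x; // branchy form
        if (is_token) a *= std::sqrt((float) n_embd);

        const float b = x * (is_token ? std::sqrt((float) n_embd) : 1.0f); // folded form

        printf("%f %f\n", a, b); // identical results
        return 0;
    }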
Some files were not shown because too many files have changed in this diff.