Compare commits

...

22 Commits

Author SHA1 Message Date
ParthSareen
afa2e855d4 log probs working 2025-01-10 11:15:31 -08:00
ParthSareen
f9928b677f Working e2e logits 2025-01-03 16:06:28 -08:00
ParthSareen
c92d418a7c WIP but got logits n stuff 2025-01-02 13:32:33 -08:00
ParthSareen
d7e7e6a01e wip 2025-01-02 13:32:33 -08:00
Simon Schampijer
844899440a examples: updated deprecated imports (#3602) 2024-12-29 14:36:25 -05:00
Anas Khan
103db4216d docs: add /api/version endpoint documentation (#8082)
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-12-29 14:33:44 -05:00
Jeffrey Morgan
6daddcde01 readme: update import header 2024-12-29 14:12:23 -05:00
Emilien Lancelot
07f7e69b36 readme: add Yacana multi-agent framework to community integrations (#7259) 2024-12-28 15:05:57 -05:00
CIIDMike
b68e8e5727 docs: add syntax highlighting on Go template code blocks (#8215) 2024-12-27 13:17:49 -05:00
Adarsh Mishra
369fb529e2 readme: add TextLLaMA to community integrations 2024-12-27 13:16:06 -05:00
Jared Donnell
023e4bca14 readme: add neollama to terminal section of community integrations (#8242) 2024-12-25 17:16:11 -05:00
aritra saha
51af455f62 readme: add alpaca client application to community integrations (#8227) 2024-12-24 23:05:35 -05:00
Emanuil Rusev
ffe3549064 readme: add IntelliBar to community integrations (#7950) 2024-12-23 12:04:18 -05:00
湛露先生
928de9050e server: reuse InvalidModelNameErrMsg type (#8163) 2024-12-23 10:38:34 -05:00
ItzCrazyKns
36aea6154a readme: add Perplexica to community-integrations (#8198) 2024-12-22 20:04:01 -05:00
Patrick Devine
dd352ab27f fix crash bug with /save when quotes are used (#8208) 2024-12-21 22:31:37 -08:00
Patrick Devine
d8bab8ea44 remove tutorials.md which pointed to removed tutorials (#8189) 2024-12-20 14:04:20 -08:00
Squishedmac
9ab62eb96f update golang.org/x dependencies (#8172) 2024-12-20 09:29:30 -08:00
Parth Sareen
290cf2040a llama: test key order preservation in schema_to_grammar (#8078)
This change adds a test to catch a regression in schema_to_grammar where
the order of keys in the JSON schema is not preserved in the generated
grammar, which is critical for step-by-step reasoning.
2024-12-18 19:44:50 -08:00
Jeffrey Morgan
a72f2dce45 scripts: sign renamed macOS binary (#8131) 2024-12-17 18:03:49 -08:00
Jesse Gross
08a832b482 llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not being able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.

Fixes #7949
2024-12-17 14:01:19 -08:00
Blake Mizerany
2ddc32d5c5 llm: do not error on "null" format (#8139)
This fixes another regression in the previous commit that fixed other
known bugs.
2024-12-17 09:49:37 -08:00
20 changed files with 757 additions and 259 deletions

View File

@@ -97,7 +97,7 @@ Ollama supports importing GGUF models in the Modelfile:
ollama run example
```
### Import from PyTorch or Safetensors
### Import from Safetensors
See the [guide](docs/import.md) on importing models for more information.
@@ -298,6 +298,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@@ -327,6 +328,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
@@ -361,6 +363,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
### Cloud
@@ -373,6 +376,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [oterm](https://github.com/ggozad/oterm)
- [Ellama Emacs client](https://github.com/s-kostyaev/ellama)
- [Emacs client](https://github.com/zweifisch/ollama)
- [neollama](https://github.com/paradoxical-dev/neollama) UI client for interacting with models from within Neovim
- [gen.nvim](https://github.com/David-Kunz/gen.nvim)
- [ollama.nvim](https://github.com/nomnivore/ollama.nvim)
- [ollero.nvim](https://github.com/marco-souza/ollero.nvim)
@@ -427,6 +431,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
- [Spring AI](https://github.com/spring-projects/spring-ai) with [reference](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html) and [example](https://github.com/tzolov/ollama-tools)
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
@@ -519,6 +524,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
### Supported backends

View File

@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil
}
const maxBufferSize = 512 * format.KiloByte
const maxBufferSize = 1024 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer

View File

@@ -80,6 +80,8 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
ReturnLogits bool `json:"return_logits,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +107,8 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
ReturnLogits bool `json:"return_logits,omitempty"`
}
type Tools []Tool
@@ -185,10 +189,12 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Logits []float32 `json:"logits"`
TopLogprobs []TokenLogprob `json:"top_logprobs"`
Done bool `json:"done"`
@@ -204,6 +210,11 @@ type Metrics struct {
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
type TokenLogprob struct {
Text string `json:"text"`
Logprob float32 `json:"logprob"`
}
// Options specified in [GenerateRequest]. If you add a new option here, also
// add it to the API docs.
type Options struct {
@@ -450,6 +461,8 @@ type GenerateResponse struct {
Context []int `json:"context,omitempty"`
Metrics
Logits []float32 `json:"logits"`
}
// ModelDetails provides details about a model.

View File

@@ -485,6 +485,9 @@ func buildModelfile(opts runOptions) string {
}
for _, msg := range opts.Messages {
if strings.Contains(msg.Content, "\"") {
msg.Content = `"""` + msg.Content + `"""`
}
f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)})
}

View File

@@ -13,6 +13,7 @@
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Version](#version)
## Conventions
@@ -1526,3 +1527,29 @@ curl http://localhost:11434/api/embeddings -d '{
]
}
```
## Version
```shell
GET /api/version
```
Retrieve the Ollama version
### Examples
#### Request
```shell
curl http://localhost:11434/api/version
```
#### Response
```json
{
"version": "0.5.1"
}
```

View File

@@ -111,7 +111,7 @@ Keep the following tips and best practices in mind when working with Go template
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
```gotmpl
```go
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
@@ -125,7 +125,7 @@ Tools support can be added to a model by adding a `{{ .Tools }}` node to the tem
Mistral v0.3 and Mixtral 8x22B supports tool calling.
```gotmpl
```go
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
@@ -151,7 +151,7 @@ Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node
CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle.
```gotmpl
```go
<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
```

View File

@@ -1,9 +0,0 @@
# Tutorials
Here is a list of ways you can use Ollama with other tools to build interesting applications.
- [Using LangChain with Ollama in JavaScript](./tutorials/langchainjs.md)
- [Using LangChain with Ollama in Python](./tutorials/langchainpy.md)
- [Running Ollama on NVIDIA Jetson Devices](./tutorials/nvidia-jetson.md)
Also be sure to check out the [examples](../examples) directory for more ways to use Ollama.

View File

@@ -1,8 +1,8 @@
from langchain.document_loaders import OnlinePDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import GPT4AllEmbeddings
from langchain import PromptTemplate
from langchain.llms import Ollama
from langchain_community.document_loaders import OnlinePDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA

10
go.mod
View File

@@ -12,7 +12,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.9.0
golang.org/x/sync v0.10.0
)
require (
@@ -68,12 +68,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.23.0
golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
golang.org/x/text v0.20.0
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

20
go.sum
View File

@@ -212,8 +212,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -266,8 +266,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -283,17 +283,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -0,0 +1,107 @@
package grammar
import (
"bufio"
"bytes"
"strings"
"testing"
"github.com/ollama/ollama/llama"
)
// https://github.com/ollama/ollama/issues/7978
const issue7978JSONSchema = `{
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" },
"nested": {
"type": "object",
"properties": {
"deep": { "type": "string" }
}
}
},
"required": ["explanation", "output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" },
"01_numbered_key": { "type": "string" },
"numbers": {
"type": "array",
"items": { "type": "number" }
},
"booleans": {
"type": "array",
"items": { "type": "boolean" }
},
"mixed": {
"type": "array",
"items": {
"oneOf": [
{ "type": "string" },
{ "type": "number" },
{ "type": "boolean" }
]
}
}
},
"required": ["steps", "final_answer"],
"additionalProperties": false
}`
func TestIssue7978(t *testing.T) {
g := llama.SchemaToGrammar([]byte(issue7978JSONSchema))
if g == nil {
t.Fatal("failed to convert JSON schema to grammar")
}
t.Logf("grammar:\n%s", g)
t.Log()
var got string
s := bufio.NewScanner(bytes.NewReader(g))
for s.Scan() {
line := strings.TrimSpace(s.Text())
step, _, _ := strings.Cut(line, " ::= ")
step = strings.TrimSpace(step)
if step == "root" {
got = line
}
}
want := `root ::= "{" space steps-kv "," space final-answer-kv ( "," space ( 01-numbered-key-kv 01-numbered-key-rest | numbers-kv numbers-rest | booleans-kv booleans-rest | mixed-kv ) )? "}" space`
if got != want {
t.Errorf("root =\n%qwant:\n%q", got, want)
}
}
func TestSchemaToGrammer(t *testing.T) {
cases := []struct {
schema string
prefix []byte // nil is check as nil
}{
{`invalid`, nil},
// Simple heuristic/smoke test
{`{"type":"object"}`, []byte("root ::= object")},
}
for _, c := range cases {
t.Run("x", func(t *testing.T) {
g := llama.SchemaToGrammar([]byte(c.schema))
if c.prefix == nil && g != nil {
t.Fatalf("grammar = %v, want nil", g)
}
if !bytes.HasPrefix(g, c.prefix) {
t.Errorf("grammar = %q, want %q", g, c.prefix)
}
})
}
}

View File

@@ -1,76 +0,0 @@
package llama
import (
"bufio"
"bytes"
"strings"
"testing"
)
// https://github.com/ollama/ollama/issues/7978
const issue7978JSONSchema = `{
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation", "output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["steps", "final_answer"],
"additionalProperties": false
}`
func TestIssue7978(t *testing.T) {
g := SchemaToGrammar([]byte(issue7978JSONSchema))
if g == nil {
t.Fatal("failed to convert JSON schema to grammar")
}
t.Logf("grammar:\n%s", g)
t.Log()
var sawSteps bool
s := bufio.NewScanner(bytes.NewReader(g))
for s.Scan() {
line := s.Text()
if strings.Contains(line, "steps") {
sawSteps = true
}
if strings.Contains(line, "final-answer") && !sawSteps {
t.Error("expected 'steps' before 'final-answer'")
}
}
}
func TestSchemaToGrammer(t *testing.T) {
cases := []struct {
schema string
prefix []byte // nil is check as nil
}{
{`invalid`, nil},
// Simple heuristic/smoke test
{`{"type":"object"}`, []byte("root ::= object")},
}
for _, c := range cases {
t.Run("x", func(t *testing.T) {
g := SchemaToGrammar([]byte(c.schema))
if c.prefix == nil && g != nil {
t.Fatalf("grammar = %v, want nil", g)
}
if !bytes.HasPrefix(g, c.prefix) {
t.Errorf("grammar = %q, want %q", g, c.prefix)
}
})
}
}

99
llama/llama.cpp vendored
View File

@@ -3051,6 +3051,13 @@ struct llama_kv_cache {
}
};
// block of KV slots to move when defragging
struct llama_kv_defrag_move {
uint32_t src;
uint32_t dst;
uint32_t len;
};
struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<ggml_context_ptr> ctxs;
@@ -10828,35 +10835,23 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
for (uint32_t i = 0; i < ids.size(); ++i) {
const uint32_t id = ids[i];
if (i == id || id == ids.size()) {
continue;
}
uint32_t nm = 1;
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}
for (const auto & move : moves) {
for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm,
n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm,
n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
@@ -10864,31 +10859,29 @@ struct llm_build_context {
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, i));
ggml_row_size(kv_self.v_l[il]->type, move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, id));
ggml_row_size(kv_self.v_l[il]->type, move.dst));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}
i += nm - 1;
}
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -17351,7 +17344,7 @@ struct llm_build_context {
}
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
llama_ubatch dummy = {};
dummy.equal_seqs = true;
@@ -17361,7 +17354,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
llm.init();
struct ggml_cgraph * result = llm.build_defrag(ids);
struct ggml_cgraph * result = llm.build_defrag(moves);
llm.free();
@@ -18377,7 +18370,12 @@ static int llama_decode_internal(
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
llama_kv_cache_defrag(kv_self);
llama_kv_cache_update(&lctx);
slot = llama_kv_cache_find_slot(kv_self, ubatch);
}
if (!slot) {
return 1;
}
@@ -18782,8 +18780,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
//const int64_t t_start = ggml_time_us();
// number of cells moved
uint32_t n_moves = 0;
// groups of cells moved
std::vector<struct llama_kv_defrag_move> moves;
// each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
@@ -18847,19 +18845,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// are we moving a continuous block of memory?
bool cont = false;
// should we stop searching for the next move?
bool stop = false;
// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) {
if (n_moves == max_moves) {
stop = true;
break;
}
cont = false;
continue;
}
@@ -18875,8 +18865,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
kv_self.head = n_used;
if (!cont) {
n_moves++;
moves.push_back({i1, i0 + nf, 1});
cont = true;
} else {
moves.back().len++;
}
nf++;
@@ -18886,22 +18878,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
}
if (stop || n_moves == max_moves) {
break;
}
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1;
}
if (n_moves == 0) {
if (moves.size() == 0) {
return;
}
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
#if 0
// CPU defrag
@@ -18976,11 +18962,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else
// ggml_graph defrag
ggml_backend_sched_reset(lctx.sched.get());
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
std::vector<struct llama_kv_defrag_move> chunk;
auto end = std::min(i + max_moves, moves.size());
chunk.assign(moves.begin() + i, moves.begin() + end);
ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
ggml_backend_sched_reset(lctx.sched.get());
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
}
#endif
//const int64_t t_end = ggml_time_us();

View File

@@ -260,6 +260,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
func (m *Model) Detokenize(tokens []int) (string, error) {
var text string
for _, token := range tokens {
piece := m.TokenToPiece(token)
if piece == "" {
return "", fmt.Errorf("failed to convert token %d to piece", token)
}
text += piece
}
return text, nil
}
type ModelParams struct {
NumGpuLayers int
MainGpu int

View File

@@ -0,0 +1,242 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Fri, 13 Dec 2024 16:11:59 -0800
Subject: [PATCH] llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not being able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.
In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.
---
src/llama.cpp | 99 ++++++++++++++++++++++++---------------------------
1 file changed, 46 insertions(+), 53 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 4778a9ed..654e32bc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3025,6 +3025,13 @@ struct llama_kv_cache {
}
};
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+ uint32_t src;
+ uint32_t dst;
+ uint32_t len;
+};
+
struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<ggml_context_ptr> ctxs;
@@ -10802,35 +10809,23 @@ struct llm_build_context {
return gf;
}
- struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+ struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
+ for (const auto & move : moves) {
for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
- n_embd_k_gqa, nm,
+ n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
- n_embd_k_gqa, nm,
+ n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
@@ -10838,31 +10833,29 @@ struct llm_build_context {
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
- n_embd_v_gqa, nm,
+ n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
- n_embd_v_gqa, nm,
+ n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
- nm, n_embd_v_gqa,
+ move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
- ggml_row_size(kv_self.v_l[il]->type, i));
+ ggml_row_size(kv_self.v_l[il]->type, move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
- nm, n_embd_v_gqa,
+ move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
- ggml_row_size(kv_self.v_l[il]->type, id));
+ ggml_row_size(kv_self.v_l[il]->type, move.dst));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}
-
- i += nm - 1;
}
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -17325,7 +17318,7 @@ struct llm_build_context {
}
};
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
llama_ubatch dummy = {};
dummy.equal_seqs = true;
@@ -17335,7 +17328,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
llm.init();
- struct ggml_cgraph * result = llm.build_defrag(ids);
+ struct ggml_cgraph * result = llm.build_defrag(moves);
llm.free();
@@ -18351,7 +18344,12 @@ static int llama_decode_internal(
kv_self.head = 0;
}
- const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+ auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+ if (!slot) {
+ llama_kv_cache_defrag(kv_self);
+ llama_kv_cache_update(&lctx);
+ slot = llama_kv_cache_find_slot(kv_self, ubatch);
+ }
if (!slot) {
return 1;
}
@@ -18756,8 +18754,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
//const int64_t t_start = ggml_time_us();
- // number of cells moved
- uint32_t n_moves = 0;
+ // groups of cells moved
+ std::vector<struct llama_kv_defrag_move> moves;
// each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
@@ -18821,19 +18819,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// are we moving a continuous block of memory?
bool cont = false;
- // should we stop searching for the next move?
- bool stop = false;
-
// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) {
- if (n_moves == max_moves) {
- stop = true;
- break;
- }
-
cont = false;
continue;
}
@@ -18849,8 +18839,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
kv_self.head = n_used;
if (!cont) {
- n_moves++;
+ moves.push_back({i1, i0 + nf, 1});
cont = true;
+ } else {
+ moves.back().len++;
}
nf++;
@@ -18860,22 +18852,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
}
- if (stop || n_moves == max_moves) {
- break;
- }
-
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1;
}
- if (n_moves == 0) {
+ if (moves.size() == 0) {
return;
}
- //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
- //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+ //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
#if 0
// CPU defrag
@@ -18950,11 +18936,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else
// ggml_graph defrag
- ggml_backend_sched_reset(lctx.sched.get());
+ for (std::size_t i = 0; i < moves.size(); i += max_moves) {
+ std::vector<struct llama_kv_defrag_move> chunk;
+ auto end = std::min(i + max_moves, moves.size());
+ chunk.assign(moves.begin() + i, moves.begin() + end);
- ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+ ggml_backend_sched_reset(lctx.sched.get());
+
+ //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
+ ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
- llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+ }
#endif
//const int64_t t_end = ggml_time_us();

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
@@ -59,7 +60,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan string
responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -88,6 +89,15 @@ type Sequence struct {
startGenerationTime time.Time
numDecoded int
numPromptInputs int
// New flag we need to add to Sequence struct
returnLogits bool
// Using our new GetLogits() method
logits []float32
// Add new channel for logits
logitsOut chan []float32
}
type NewSequenceParams struct {
@@ -96,6 +106,7 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
returnLogits bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -149,13 +160,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
returnLogits: params.returnLogits,
logitsOut: make(chan []float32, 100),
}, nil
}
@@ -274,25 +287,34 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
if len(seq.pendingResponses) == 0 {
return true
}
content := strings.Join(seq.pendingResponses, "")
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
seq.pendingResponses = nil
resp := CompletionResponse{
Content: content,
}
if len(joined) == 0 {
return true
// Add logits if requested and available
if seq.returnLogits && seq.logits != nil {
resp.Logits = seq.logits
seq.logits = nil
}
select {
case seq.responses <- joined:
case seq.responses <- resp:
return true
case <-seq.quit:
return false
@@ -433,14 +455,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
err := s.lc.Decode(batch)
if err != nil {
if errors.Is(err, llama.ErrKvCacheFull) {
slog.Debug("defragmenting kv cache")
s.cache.lc.KvCacheDefrag()
err = s.lc.Decode(batch)
}
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
return fmt.Errorf("failed to decode batch: %w", err)
}
if crossAttention {
@@ -483,7 +498,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
// sample a token
// Before sampling:
if seq.returnLogits { // New flag we need to add to Sequence struct
logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits))
copy(seq.logits, logits)
}
// Then sample token
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
seq.samplingCtx.Accept(token, true)
piece := s.model.TokenToPiece(token)
@@ -579,10 +601,11 @@ type ImageData struct {
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
ReturnLogits bool `json:"return_logits,omitempty"` // defaults to false
Options
}
@@ -595,8 +618,10 @@ type Timings struct {
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Content string `json:"content"`
Logits []float32 `json:"logits,omitempty"`
Tokens []string `json:"tokens,omitempty"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
@@ -644,12 +669,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
slog.Info("completion request", "return_logits", req.ReturnLogits)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
returnLogits: req.ReturnLogits,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -699,9 +726,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
// slog.Info("content", "content", content.Content)
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
@@ -1010,3 +1036,76 @@ func Execute(args []string) error {
cancel()
return nil
}
// // Helper function to get top K logits and convert to log probabilities
// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
// if k <= 0 {
// return nil
// }
// // Convert logits to probabilities using softmax
// probs := softmax(logits)
// // Create slice of index/probability pairs
// pairs := make([]struct {
// token int
// prob float32
// }, len(probs))
// for i, p := range probs {
// pairs[i] = struct {
// token int
// prob float32
// }{i, p}
// }
// // Sort by probability (descending)
// sort.Slice(pairs, func(i, j int) bool {
// return pairs[i].prob > pairs[j].prob
// })
// // Take top K
// k = min(k, len(pairs))
// result := make([]api.LogProbs, k)
// for i := 0; i < k; i++ {
// result[i] = api.LogProbs{
// TopLogprobs: []api.TokenLogprob{
// {
// Token: model.TokenToPiece(pairs[i].token),
// Logprob: float32(math.Log(float64(pairs[i].prob))),
// },
// },
// }
// }
// return result
// }
// Helper function to compute softmax
func softmax(logits []float32) []float32 {
probs := make([]float32, len(logits))
// Find max for numerical stability
max := float32(math.Inf(-1))
for _, l := range logits {
if l > max {
max = l
}
}
// Compute exp(x - max) and sum
sum := float32(0)
for i, l := range logits {
ex := float32(math.Exp(float64(l - max)))
probs[i] = ex
sum += ex
}
// Normalize
for i := range probs {
probs[i] /= sum
}
return probs
}

View File

@@ -633,7 +633,8 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
ws ::= ([ \t\n] ws)?
`
const maxBufferSize = 512 * format.KiloByte
// TODO: change back to 512 * format.KiloByte
const maxBufferSize = 2048 * format.KiloByte
type ImageData struct {
Data []byte `json:"data"`
@@ -642,11 +643,12 @@ type ImageData struct {
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Logits []float32 `json:"logits,omitempty"`
Timings struct {
PredictedN int `json:"predicted_n"`
@@ -657,10 +659,11 @@ type completion struct {
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
ReturnLogits bool
}
type CompletionResponse struct {
@@ -671,6 +674,7 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Logits []float32
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -696,24 +700,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"return_logits": req.ReturnLogits,
"cache_prompt": true,
}
if len(req.Format) > 0 {
switch {
case bytes.Equal(req.Format, []byte(`""`)):
// fallthrough
case bytes.Equal(req.Format, []byte(`"json"`)):
switch string(req.Format) {
case `null`, `""`:
// Field was set, but "missing" a value. We accept
// these as "not set".
break
case `"json"`:
request["grammar"] = grammarJSON
case bytes.HasPrefix(req.Format, []byte("{")):
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
default:
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format)
}
}
@@ -817,6 +826,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
Logits: c.Logits,
})
}
@@ -833,6 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
Logits: c.Logits,
})
return nil
}

View File

@@ -39,25 +39,34 @@ func TestLLMServerCompletionFormat(t *testing.T) {
cancel() // prevent further processing if request makes it past the format check
checkCanceled := func(err error) {
checkValid := func(err error) {
t.Helper()
if !errors.Is(err, context.Canceled) {
t.Fatalf("Completion: err = %v; expected context.Canceled", err)
}
}
valids := []string{`"json"`, `{"type":"object"}`, ``, `""`}
valids := []string{
// "missing"
``,
`""`,
`null`,
// JSON
`"json"`,
`{"type":"object"}`,
}
for _, valid := range valids {
err := s.Completion(ctx, CompletionRequest{
Options: new(api.Options),
Format: []byte(valid),
}, nil)
checkCanceled(err)
checkValid(err)
}
err := s.Completion(ctx, CompletionRequest{
Options: new(api.Options),
Format: nil, // missing format
}, nil)
checkCanceled(err)
checkValid(err)
}

View File

@@ -15,28 +15,36 @@ export CGO_CXXFLAGS=-mmacosx-version-min=11.3
export CGO_LDFLAGS=-mmacosx-version-min=11.3
rm -rf llama/build dist/darwin-*
# Generate the universal ollama binary for stand-alone usage: metal + avx
echo "Building binary"
echo "Building darwin arm64"
GOOS=darwin ARCH=arm64 GOARCH=arm64 make -j 8 dist
echo "Building darwin amd64 with AVX enabled"
GOOS=darwin ARCH=amd64 GOARCH=amd64 CUSTOM_CPU_FLAGS="avx" make -j 8 dist_exe
# Generate the universal ollama binary for stand-alone usage: metal + avx
lipo -create -output dist/ollama-darwin dist/darwin-arm64/bin/ollama dist/darwin-amd64/bin/ollama
# sign the binary and rename it
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama-darwin
else
echo "WARNING: Skipping code signing - set APPLE_IDENTITY"
fi
ditto -c -k --keepParent dist/ollama-darwin dist/temp.zip
if [ -n "$APPLE_IDENTITY" ]; then
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
fi
rm -f dist/temp.zip
# Build the app bundle
echo "Building app"
echo "Building darwin amd64 with runners"
rm dist/darwin-amd64/bin/ollama
GOOS=darwin ARCH=amd64 GOARCH=amd64 make -j 8 dist
# Generate the universal ollama binary for the app bundle: metal + no-avx
lipo -create -output dist/ollama dist/darwin-arm64/bin/ollama dist/darwin-amd64/bin/ollama
if [ -n "$APPLE_IDENTITY" ]; then
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
else
echo "Skipping code signing - set APPLE_IDENTITY"
fi
chmod +x dist/ollama
# build and optionally sign the mac app
npm install --prefix macapp
if [ -n "$APPLE_IDENTITY" ]; then
@@ -46,14 +54,3 @@ else
fi
cp macapp/out/make/zip/darwin/universal/Ollama-darwin-universal-$VERSION.zip dist/Ollama-darwin.zip
# sign the binary and rename it
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
else
echo "WARNING: Skipping code signing - set APPLE_IDENTITY"
fi
ditto -c -k --keepParent dist/ollama dist/temp.zip
if [ -n "$APPLE_IDENTITY" ]; then
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
fi
rm -f dist/temp.zip

View File

@@ -19,6 +19,7 @@ import (
"os/signal"
"path/filepath"
"slices"
"sort"
"strings"
"syscall"
"time"
@@ -142,7 +143,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
switch {
case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -295,10 +296,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ReturnLogits: false,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -312,6 +314,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
},
Logits: cr.Logits,
}
if _, err := sb.WriteString(cr.Content); err != nil {
@@ -568,7 +571,7 @@ func (s *Server) PullHandler(c *gin.Context) {
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return
}
@@ -829,7 +832,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1470,7 +1473,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1547,26 +1550,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ReturnLogits: true,
}, func(cr llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
DoneReason: r.DoneReason,
Message: api.Message{Role: "assistant", Content: cr.Content},
Done: cr.Done,
DoneReason: cr.DoneReason,
Logits: []float32{},
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
},
}
if r.Done {
topK := int(3)
logits := make([]float32, len(cr.Logits))
copy(logits, cr.Logits)
res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
@@ -1582,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
sb.WriteString(cr.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
@@ -1595,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if r.Done {
if cr.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
@@ -1645,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
// Calculate softmax denominator first (log sum exp trick for numerical stability)
maxLogit := float32(math.Inf(-1))
for _, logit := range logits {
if logit > maxLogit {
maxLogit = logit
}
}
var sumExp float32
for _, logit := range logits {
sumExp += float32(math.Exp(float64(logit - maxLogit)))
}
logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
// Calculate log probs and track top K
logProbs := make([]api.TokenLogprob, len(logits))
for i, logit := range logits {
text, err := s.Detokenize(ctx, []int{i})
if err != nil {
slog.Error("detokenize error for logprob", "error", err)
continue
}
logProbs[i] = api.TokenLogprob{
Text: text,
Logprob: logit - logSumExp,
}
}
// Sort by logprob descending and take top K
sort.Slice(logProbs, func(i, j int) bool {
return logProbs[i].Logprob > logProbs[j].Logprob
})
if len(logProbs) > topK {
logProbs = logProbs[:topK]
}
return logProbs
}
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):