mirror of
https://github.com/ollama/ollama.git
synced 2025-12-24 08:10:54 -05:00
Compare commits
119 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
adeb40eaf2 | ||
|
|
d7d33e5255 | ||
|
|
63bc884e25 | ||
|
|
ef4e095d24 | ||
|
|
4d4f75a8a8 | ||
|
|
3f71ba406a | ||
|
|
88a67127d8 | ||
|
|
f7dc7dcc64 | ||
|
|
04f971c84b | ||
|
|
70edb9bc4d | ||
|
|
3f0ed03856 | ||
|
|
4736391bfb | ||
|
|
7c5330413b | ||
|
|
39d9d22ca3 | ||
|
|
af47413dba | ||
|
|
d091fe3c21 | ||
|
|
ee02f548c8 | ||
|
|
b08870aff3 | ||
|
|
3ecae420ac | ||
|
|
4cbbf0e13b | ||
|
|
380378cc80 | ||
|
|
0963c65027 | ||
|
|
ed740a2504 | ||
|
|
c9f98622b1 | ||
|
|
0a954e5066 | ||
|
|
aa93423fbf | ||
|
|
01c9386267 | ||
|
|
af9eb36f9f | ||
|
|
06093fd396 | ||
|
|
86b7fcac32 | ||
|
|
fb8ddc564e | ||
|
|
242efe6611 | ||
|
|
1b0e6c9c0e | ||
|
|
dfa2f32ca0 | ||
|
|
840424a2c4 | ||
|
|
f56aa20014 | ||
|
|
6707768ebd | ||
|
|
c78bb76a12 | ||
|
|
942c979232 | ||
|
|
06164911dd | ||
|
|
2a21363bb7 | ||
|
|
026869915f | ||
|
|
45d61aaaa3 | ||
|
|
20f6c06569 | ||
|
|
371f5e52aa | ||
|
|
e006480e49 | ||
|
|
aed545872d | ||
|
|
44869c59d6 | ||
|
|
52663284cf | ||
|
|
42fa9d7f0a | ||
|
|
b7a87a22b6 | ||
|
|
e8aaea030e | ||
|
|
b1ad3a43cb | ||
|
|
267e25a750 | ||
|
|
9a32c514cb | ||
|
|
e9ae607ece | ||
|
|
93707fa3f2 | ||
|
|
94c369095f | ||
|
|
9164b0161b | ||
|
|
e592e8fccb | ||
|
|
bf4fc25f7b | ||
|
|
5b806d8d24 | ||
|
|
cb1e072643 | ||
|
|
45b6a12e45 | ||
|
|
68755f1f5e | ||
|
|
997a455039 | ||
|
|
88775e1ff9 | ||
|
|
8867e744ff | ||
|
|
4fd064bea6 | ||
|
|
59fbceedcc | ||
|
|
321d57e1a0 | ||
|
|
ba26c7aa00 | ||
|
|
63c763685f | ||
|
|
34a4a94f13 | ||
|
|
f4a73d57a4 | ||
|
|
948114e3e3 | ||
|
|
a3e60d9058 | ||
|
|
8acb233668 | ||
|
|
119589fcb3 | ||
|
|
5ea844964e | ||
|
|
bd8eed57fc | ||
|
|
9cf0f2e973 | ||
|
|
176ad3aa6e | ||
|
|
4d08363580 | ||
|
|
8907bf51d2 | ||
|
|
abe614c705 | ||
|
|
238715037d | ||
|
|
c0a00f68ae | ||
|
|
f0c454ab57 | ||
|
|
089daaeabc | ||
|
|
b9f74ff3d6 | ||
|
|
fcf4d60eee | ||
|
|
e33d5c2dbc | ||
|
|
18d9a7e1f1 | ||
|
|
8488388cbd | ||
|
|
588901f449 | ||
|
|
0a7fdbe533 | ||
|
|
5950c176ca | ||
|
|
23d23409a0 | ||
|
|
9009bedf13 | ||
|
|
d4ac57e240 | ||
|
|
7b59d1770f | ||
|
|
95ead8ffba | ||
|
|
7aa08a77ca | ||
|
|
7e432cdfac | ||
|
|
586672f490 | ||
|
|
b03408de74 | ||
|
|
1e6a28bf5b | ||
|
|
d6e3b64582 | ||
|
|
114c932a8e | ||
|
|
7f7103de06 | ||
|
|
c631a9c726 | ||
|
|
8fd9e56804 | ||
|
|
8a65717f55 | ||
|
|
6d3152a98a | ||
|
|
b438d485f1 | ||
|
|
204349b17b | ||
|
|
86e67fc4a9 | ||
|
|
b99c291f47 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,4 +11,5 @@ ggml-metal.metal
|
||||
.idea
|
||||
test_data
|
||||
*.crt
|
||||
llm/build
|
||||
llm/build
|
||||
__debug_bin*
|
||||
30
README.md
30
README.md
@@ -1,5 +1,5 @@
|
||||
<div align="center">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</div>
|
||||
|
||||
# Ollama
|
||||
@@ -51,7 +51,7 @@ Here are some example models that can be downloaded:
|
||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||
| Phi-3 | 3,8B | 2.3GB | `ollama run phi3` |
|
||||
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
|
||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
||||
@@ -173,7 +173,7 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
|
||||
The image features a yellow smiley face, which is likely the central focus of the picture.
|
||||
```
|
||||
|
||||
### Pass in prompt as arguments
|
||||
### Pass the prompt as an argument
|
||||
|
||||
```
|
||||
$ ollama run llama3 "Summarize this file: $(cat README.md)"
|
||||
@@ -284,17 +284,19 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
|
||||
- [OpenAOE](https://github.com/InternLM/OpenAOE)
|
||||
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
|
||||
- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x)
|
||||
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
|
||||
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||
- [QA-Pilot: Chat with Code Repository](https://github.com/reid41/QA-Pilot)
|
||||
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
|
||||
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
|
||||
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
|
||||
- [chat: chat web app for teams](https://github.com/swuecho/chat)
|
||||
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
|
||||
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
|
||||
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
|
||||
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
|
||||
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
|
||||
- [chat](https://github.com/swuecho/chat) (chat web app for teams)
|
||||
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
|
||||
- [Ollama RAG Chatbot: Local Chat with multiples PDFs using Ollama and RAG.](https://github.com/datvodinh/rag-chatbot.git)
|
||||
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
|
||||
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -348,9 +350,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
||||
- [Elixir LangChain](https://github.com/brainlid/langchain)
|
||||
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
||||
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
|
||||
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
|
||||
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
|
||||
- [Testcontainers](https://testcontainers.com/modules/ollama/)
|
||||
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -370,12 +374,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
|
||||
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
||||
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
||||
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
|
||||
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
|
||||
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
|
||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
|
||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
||||
@@ -384,4 +389,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||
|
||||
### Supported backends
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -57,12 +58,36 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
// If the variable is not specified, a default ollama host and port will be
|
||||
// used.
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
ollamaHost, err := GetOllamaHost()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: ollamaHost.Scheme,
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OllamaHost struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
}
|
||||
|
||||
func GetOllamaHost() (OllamaHost, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
||||
hostVar := os.Getenv("OLLAMA_HOST")
|
||||
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||
|
||||
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
||||
scheme, hostport = "http", hostVar
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
@@ -82,12 +107,14 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(host, port),
|
||||
},
|
||||
http: http.DefaultClient,
|
||||
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||
return OllamaHost{}, ErrInvalidHostPort
|
||||
}
|
||||
|
||||
return OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package api
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
type testCase struct {
|
||||
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
hostTestCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: ":1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||
"zero port": {value: ":0", expect: ":0"},
|
||||
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||
}
|
||||
|
||||
for k, v := range hostTestCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
oh, err := GetOllamaHost()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
12
api/types.go
12
api/types.go
@@ -309,6 +309,7 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
|
||||
var ErrInvalidOpts = errors.New("invalid options")
|
||||
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
@@ -435,6 +436,13 @@ type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
if d.Duration < 0 {
|
||||
return []byte("-1"), nil
|
||||
}
|
||||
return []byte("\"" + d.Duration.String() + "\""), nil
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
var v any
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
@@ -448,7 +456,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
@@ -458,6 +466,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if d.Duration < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
exp: &Duration{42 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
req: `{ "keep_alive": "42m" }`,
|
||||
@@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req: `{ "keep_alive": -1 }`,
|
||||
exp: &Duration{math.MaxInt64},
|
||||
},
|
||||
{
|
||||
name: "Negative Float",
|
||||
req: `{ "keep_alive": -3.14 }`,
|
||||
exp: &Duration{math.MaxInt64},
|
||||
},
|
||||
{
|
||||
name: "Negative Integer String",
|
||||
req: `{ "keep_alive": "-1m" }`,
|
||||
@@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
"negative duration",
|
||||
time.Duration(-1),
|
||||
time.Duration(math.MaxInt64),
|
||||
},
|
||||
{
|
||||
"positive duration",
|
||||
time.Duration(42 * time.Second),
|
||||
time.Duration(42 * time.Second),
|
||||
},
|
||||
{
|
||||
"another positive duration",
|
||||
time.Duration(42 * time.Minute),
|
||||
time.Duration(42 * time.Minute),
|
||||
},
|
||||
{
|
||||
"zero duration",
|
||||
time.Duration(0),
|
||||
time.Duration(0),
|
||||
},
|
||||
{
|
||||
"max duration",
|
||||
time.Duration(math.MaxInt64),
|
||||
time.Duration(math.MaxInt64),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b, err := json.Marshal(Duration{test.input})
|
||||
require.NoError(t, err)
|
||||
|
||||
var d Duration
|
||||
err = json.Unmarshal(b, &d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,12 +5,14 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
)
|
||||
|
||||
func InitLogging() {
|
||||
level := slog.LevelInfo
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
if envconfig.Debug {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
|
||||
@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
|
||||
return command
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err := os.Stat(logDir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||
cmd := getCmd(ctx, getCLIFullPath(command))
|
||||
// send stdout and stderr to a file
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
|
||||
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// TODO - rotation
|
||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||
if err != nil {
|
||||
return done, fmt.Errorf("failed to create server log %w", err)
|
||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(ServerLogFile)
|
||||
_, err = os.Stat(logDir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer logFile.Close()
|
||||
io.Copy(logFile, stdout) //nolint:errcheck
|
||||
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return done, fmt.Errorf("failed to start server %w", err)
|
||||
return nil, fmt.Errorf("failed to start server %w", err)
|
||||
}
|
||||
if cmd.Process != nil {
|
||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
|
||||
go func() {
|
||||
// Keep the server running unless we're shuttind down the app
|
||||
crashCount := 0
|
||||
for {
|
||||
slog.Info("starting server...")
|
||||
cmd, err := start(ctx, command)
|
||||
if err != nil {
|
||||
crashCount++
|
||||
slog.Error(fmt.Sprintf("failed to start server %s", err))
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
continue
|
||||
}
|
||||
|
||||
cmd.Wait() //nolint:errcheck
|
||||
stdin.Close()
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
default:
|
||||
crashCount++
|
||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := cmd.Start(); err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to restart server %s", err))
|
||||
// Keep trying, but back off if we keep failing
|
||||
time.Sleep(time.Duration(crashCount) * time.Second)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return done, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
|
||||
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
|
||||
}
|
||||
// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
|
||||
// TODO - temporarily disable since we're pinning in debug mode for the preview
|
||||
// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
|
||||
// make the upgrade as quiet as possible (no GUI, no prompts)
|
||||
installArgs = append(installArgs,
|
||||
"/SP", // Skip the "This will install... Do you wish to continue" prompt
|
||||
"/SUPPRESSMSGBOXES",
|
||||
"/SILENT",
|
||||
"/VERYSILENT",
|
||||
)
|
||||
// }
|
||||
|
||||
// Safeguard in case we have requests in flight that need to drain...
|
||||
slog.Info("Waiting for server to shutdown")
|
||||
|
||||
@@ -88,8 +88,8 @@ DialogFontSize=12
|
||||
[Files]
|
||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||
#if DirExists("..\dist\windows-amd64\rocm")
|
||||
|
||||
@@ -1,71 +1,71 @@
|
||||
//go:build windows
|
||||
|
||||
package wintray
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
updatAvailableMenuID = 1
|
||||
updateMenuID = updatAvailableMenuID + 1
|
||||
separatorMenuID = updateMenuID + 1
|
||||
diagLogsMenuID = separatorMenuID + 1
|
||||
diagSeparatorMenuID = diagLogsMenuID + 1
|
||||
quitMenuID = diagSeparatorMenuID + 1
|
||||
)
|
||||
|
||||
func (t *winTray) initMenus() error {
|
||||
if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w\n", err)
|
||||
}
|
||||
if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTray) UpdateAvailable(ver string) error {
|
||||
if !t.updateNotified {
|
||||
slog.Debug("updating menu and sending notification for new update")
|
||||
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write icon data to temp file: %w", err)
|
||||
}
|
||||
if err := wt.setIcon(iconFilePath); err != nil {
|
||||
return fmt.Errorf("unable to set icon: %w", err)
|
||||
}
|
||||
t.updateNotified = true
|
||||
|
||||
t.pendingUpdate = true
|
||||
// Now pop up the notification
|
||||
t.muNID.Lock()
|
||||
defer t.muNID.Unlock()
|
||||
copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
|
||||
copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
|
||||
t.nid.Flags |= NIF_INFO
|
||||
t.nid.Timeout = 10
|
||||
t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
|
||||
err = t.nid.modify()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
//go:build windows
|
||||
|
||||
package wintray
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
updatAvailableMenuID = 1
|
||||
updateMenuID = updatAvailableMenuID + 1
|
||||
separatorMenuID = updateMenuID + 1
|
||||
diagLogsMenuID = separatorMenuID + 1
|
||||
diagSeparatorMenuID = diagLogsMenuID + 1
|
||||
quitMenuID = diagSeparatorMenuID + 1
|
||||
)
|
||||
|
||||
func (t *winTray) initMenus() error {
|
||||
if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w\n", err)
|
||||
}
|
||||
if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTray) UpdateAvailable(ver string) error {
|
||||
if !t.updateNotified {
|
||||
slog.Debug("updating menu and sending notification for new update")
|
||||
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
|
||||
return fmt.Errorf("unable to create menu entries %w", err)
|
||||
}
|
||||
iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write icon data to temp file: %w", err)
|
||||
}
|
||||
if err := wt.setIcon(iconFilePath); err != nil {
|
||||
return fmt.Errorf("unable to set icon: %w", err)
|
||||
}
|
||||
t.updateNotified = true
|
||||
|
||||
t.pendingUpdate = true
|
||||
// Now pop up the notification
|
||||
t.muNID.Lock()
|
||||
defer t.muNID.Unlock()
|
||||
copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
|
||||
copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
|
||||
t.nid.Flags |= NIF_INFO
|
||||
t.nid.Timeout = 10
|
||||
t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
|
||||
err = t.nid.modify()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
36
auth/auth.go
36
auth/auth.go
@@ -10,12 +10,44 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||
|
||||
return strings.TrimSpace(string(publicKey)), nil
|
||||
}
|
||||
|
||||
func NewNonce(r io.Reader, length int) (string, error) {
|
||||
nonce := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
||||
101
cmd/cmd.go
101
cmd/cmd.go
@@ -32,10 +32,12 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -54,12 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
modelfile, err := os.ReadFile(filename)
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
commands, err := parser.Parse(bytes.NewReader(modelfile))
|
||||
modelfile, err := model.ParseFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -73,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner := progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
|
||||
for _, c := range commands {
|
||||
switch c.Name {
|
||||
for i := range modelfile.Commands {
|
||||
switch modelfile.Commands[i].Name {
|
||||
case "model", "adapter":
|
||||
path := c.Args
|
||||
path := modelfile.Commands[i].Args
|
||||
if path == "~" {
|
||||
path = home
|
||||
} else if strings.HasPrefix(path, "~/") {
|
||||
@@ -88,7 +91,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
|
||||
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return err
|
||||
@@ -111,13 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
name := c.Name
|
||||
if c.Name == "model" {
|
||||
name = "from"
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
|
||||
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
|
||||
modelfile.Commands[i].Args = "@" + digest
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
quantization, _ := cmd.Flags().GetString("quantization")
|
||||
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
|
||||
request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
|
||||
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -357,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -404,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
host := model.ParseName(args[0]).Host
|
||||
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// re-throw an error with a more user-friendly message
|
||||
return errFromUnknownKey(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -831,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
|
||||
// retrieve the OLLAMA_HOST environment variable
|
||||
ollamaHost, err := api.GetOllamaHost()
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", "11434"
|
||||
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializeKeypair(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return server.Serve(ln)
|
||||
err = server.Serve(ln)
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func initializeKeypair() error {
|
||||
@@ -1069,7 +1124,7 @@ Environment Variables:
|
||||
RunE: ListHandler,
|
||||
}
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE TARGET",
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
Args: cobra.ExactArgs(2),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
|
||||
@@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
@@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
fmt.Printf("Created new model '%s'\n", args[1])
|
||||
continue
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Println("Cleared session context")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/set"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
|
||||
@@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
|
||||
var err error
|
||||
t, offset, err = m.readTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
slog.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
@@ -122,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||
|
||||
ggufName, err := m.GetLayerName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
slog.Error(err.Error())
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||
|
||||
ggufName, err := tf.GetLayerName(k.(string))
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
slog.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
|
||||
|
||||
60
docs/api.md
60
docs/api.md
@@ -17,7 +17,7 @@
|
||||
|
||||
### Model names
|
||||
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
|
||||
### Durations
|
||||
|
||||
@@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt": "Why is the sky blue?"
|
||||
}'
|
||||
```
|
||||
@@ -77,7 +77,7 @@ A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"response": "The",
|
||||
"done": false
|
||||
@@ -95,11 +95,11 @@ The final response in the stream also includes additional data about the generat
|
||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||
|
||||
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
|
||||
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "",
|
||||
"done": true,
|
||||
@@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"stream": false
|
||||
}'
|
||||
@@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "The sky is blue because it is the color of the sky.",
|
||||
"done": true,
|
||||
@@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt": "What color is the sky at different times of the day? Respond using JSON",
|
||||
"format": "json",
|
||||
"stream": false
|
||||
@@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-11-09T21:07:55.186497Z",
|
||||
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
|
||||
"done": true,
|
||||
@@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"stream": false,
|
||||
"options": {
|
||||
@@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"response": "The sky is blue because it is the color of the sky.",
|
||||
"done": true,
|
||||
@@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2"
|
||||
"model": "llama3"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -364,7 +364,7 @@ A single JSON object is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-12-18T19:52:07.071755Z",
|
||||
"response": "",
|
||||
"done": true
|
||||
@@ -407,7 +407,7 @@ Send a chat message with a streaming response.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -423,7 +423,7 @@ A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
@@ -438,7 +438,7 @@ Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"done": true,
|
||||
"total_duration": 4883583458,
|
||||
@@ -456,7 +456,7 @@ Final response:
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "registry.ollama.ai/library/llama2:latest",
|
||||
"model": "registry.ollama.ai/library/llama3:latest",
|
||||
"created_at": "2023-12-12T14:13:43.416799Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
@@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -519,7 +519,7 @@ A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
@@ -533,7 +533,7 @@ Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"done": true,
|
||||
"total_duration": 8113331500,
|
||||
@@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "registry.ollama.ai/library/llama2:latest",
|
||||
"model": "registry.ollama.ai/library/llama3:latest",
|
||||
"created_at": "2023-12-12T14:13:43.416799Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
@@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
|
||||
```shell
|
||||
curl http://localhost:11434/api/create -d '{
|
||||
"name": "mario",
|
||||
"modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros."
|
||||
"modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -758,7 +758,7 @@ A single JSON object will be returned.
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "llama2:latest",
|
||||
"name": "llama3:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
@@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
"name": "llama2"
|
||||
"name": "llama3"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/copy -d '{
|
||||
"source": "llama2",
|
||||
"destination": "llama2-backup"
|
||||
"source": "llama3",
|
||||
"destination": "llama3-backup"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -854,7 +854,7 @@ Delete a model and its data.
|
||||
|
||||
```shell
|
||||
curl -X DELETE http://localhost:11434/api/delete -d '{
|
||||
"name": "llama2:13b"
|
||||
"name": "llama3:13b"
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/pull -d '{
|
||||
"name": "llama2"
|
||||
"name": "llama3"
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
|
||||
or installation approach uses unusual paths, you can specify the location by
|
||||
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
||||
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
@@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD
|
||||
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
|
||||
- [Strawberry Perl](https://strawberryperl.com/)
|
||||
|
||||
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
|
||||
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
|
||||
|
||||
18
docs/faq.md
18
docs/faq.md
@@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"options": {
|
||||
"num_ctx": 4096
|
||||
@@ -88,9 +88,9 @@ On windows, Ollama inherits your user and system environment variables.
|
||||
|
||||
3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.
|
||||
|
||||
4. Click OK/Apply to save
|
||||
4. Click OK/Apply to save
|
||||
|
||||
5. Run `ollama` from a new terminal window
|
||||
5. Run `ollama` from a new terminal window
|
||||
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
@@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
|
||||
|
||||
- macOS: `~/.ollama/models`
|
||||
- Linux: `/usr/share/ollama/.ollama/models`
|
||||
- Windows: `C:\Users\<username>\.ollama\models`
|
||||
- Windows: `C:\Users\%username%\.ollama\models`
|
||||
|
||||
### How do I set them to a different location?
|
||||
|
||||
@@ -221,14 +221,20 @@ The `keep_alive` parameter can be set to:
|
||||
|
||||
For example, to preload a model and leave it in memory use:
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
|
||||
```
|
||||
|
||||
To unload the model and free up memory use:
|
||||
```shell
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
|
||||
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
|
||||
```
|
||||
|
||||
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
|
||||
|
||||
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
|
||||
|
||||
## How do I manage the maximum number of requests the server can queue
|
||||
|
||||
If too many requests are sent to the server, it will respond with a 503 error
|
||||
indicating the server is overloaded. You can adjust how many requests may be
|
||||
queue by setting `OLLAMA_MAX_QUEUE`
|
||||
@@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
|
||||
|
||||
1. Create [an account](https://ollama.com/signup)
|
||||
2. Copy your Ollama public key:
|
||||
- macOS: `cat ~/.ollama/id_ed25519.pub`
|
||||
- macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
|
||||
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
|
||||
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
|
||||
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
|
||||
@@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
|
||||
ollama cp example <your username>/example
|
||||
```
|
||||
|
||||
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
|
||||
|
||||
Then push the model:
|
||||
|
||||
```
|
||||
|
||||
@@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
|
||||
To view logs of Ollama running as a startup service, run:
|
||||
|
||||
```bash
|
||||
journalctl -u ollama
|
||||
journalctl -e -u ollama
|
||||
```
|
||||
|
||||
## Uninstall
|
||||
|
||||
@@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||
- [Examples](#examples)
|
||||
- [Instructions](#instructions)
|
||||
- [FROM (Required)](#from-required)
|
||||
- [Build from llama2](#build-from-llama2)
|
||||
- [Build from llama3](#build-from-llama3)
|
||||
- [Build from a bin file](#build-from-a-bin-file)
|
||||
- [PARAMETER](#parameter)
|
||||
- [Valid Parameters and Values](#valid-parameters-and-values)
|
||||
@@ -48,7 +48,7 @@ INSTRUCTION arguments
|
||||
An example of a `Modelfile` creating a mario blueprint:
|
||||
|
||||
```modelfile
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
# sets the temperature to 1 [higher is more creative, lower is more coherent]
|
||||
PARAMETER temperature 1
|
||||
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
|
||||
@@ -67,33 +67,25 @@ To use this:
|
||||
|
||||
More examples are available in the [examples directory](../examples).
|
||||
|
||||
### `Modelfile`s in [ollama.com/library][1]
|
||||
|
||||
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
|
||||
|
||||
- Option 1: view a details page from a model's tags page:
|
||||
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
|
||||
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
|
||||
3. Scroll down to "Layers"
|
||||
- Note: if the [`FROM` instruction](#from-required) is not present,
|
||||
it means the model was created from a local file
|
||||
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
|
||||
To view the Modelfile of a given model, use the `ollama show --modelfile` command.
|
||||
|
||||
```bash
|
||||
> ollama show --modelfile llama2:13b
|
||||
> ollama show --modelfile llama3
|
||||
# Modelfile generated by "ollama show"
|
||||
# To build a new Modelfile based on this one, replace the FROM line with:
|
||||
# FROM llama2:13b
|
||||
# FROM llama3:latest
|
||||
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
|
||||
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
FROM /root/.ollama/models/blobs/sha256:123abc
|
||||
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ end }}{{ .Prompt }} [/INST] """
|
||||
SYSTEM """"""
|
||||
PARAMETER stop [INST]
|
||||
PARAMETER stop [/INST]
|
||||
PARAMETER stop <<SYS>>
|
||||
PARAMETER stop <</SYS>>
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ .Response }}<|eot_id|>"""
|
||||
PARAMETER stop "<|start_header_id|>"
|
||||
PARAMETER stop "<|end_header_id|>"
|
||||
PARAMETER stop "<|eot_id|>"
|
||||
PARAMETER stop "<|reserved_special_token"
|
||||
```
|
||||
|
||||
## Instructions
|
||||
@@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
|
||||
FROM <model name>:<tag>
|
||||
```
|
||||
|
||||
#### Build from llama2
|
||||
#### Build from llama3
|
||||
|
||||
```modelfile
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
```
|
||||
|
||||
A list of available base models:
|
||||
|
||||
@@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
|
||||
'content': 'Say this is a test',
|
||||
}
|
||||
],
|
||||
model='llama2',
|
||||
model='llama3',
|
||||
)
|
||||
```
|
||||
|
||||
@@ -43,7 +43,7 @@ const openai = new OpenAI({
|
||||
|
||||
const chatCompletion = await openai.chat.completions.create({
|
||||
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||
model: 'llama2',
|
||||
model: 'llama3',
|
||||
})
|
||||
```
|
||||
|
||||
@@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
|
||||
curl http://localhost:11434/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama2",
|
||||
"model": "llama3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
|
||||
Before using a model, pull it locally `ollama pull`:
|
||||
|
||||
```shell
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
### Default model names
|
||||
@@ -121,7 +121,7 @@ ollama pull llama2
|
||||
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
|
||||
|
||||
```
|
||||
ollama cp llama2 gpt-3.5-turbo
|
||||
ollama cp llama3 gpt-3.5-turbo
|
||||
```
|
||||
|
||||
Afterwards, this new model name can be specified the `model` field:
|
||||
|
||||
@@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
|
||||
|
||||
const ollama = new Ollama({
|
||||
baseUrl: "http://localhost:11434",
|
||||
model: "llama2",
|
||||
model: "llama3",
|
||||
});
|
||||
|
||||
const answer = await ollama.invoke(`why is the sky blue?`);
|
||||
@@ -23,10 +23,10 @@ const answer = await ollama.invoke(`why is the sky blue?`);
|
||||
console.log(answer);
|
||||
```
|
||||
|
||||
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
|
||||
That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
|
||||
|
||||
```bash
|
||||
npm install cheerio
|
||||
npm install cheerio
|
||||
```
|
||||
|
||||
```javascript
|
||||
|
||||
@@ -17,10 +17,12 @@ Let's start by asking a simple question that we can get an answer to from the **
|
||||
Then we can create a model and ask the question:
|
||||
|
||||
```python
|
||||
from langchain.llms import Ollama
|
||||
ollama = Ollama(base_url='http://localhost:11434',
|
||||
model="llama2")
|
||||
print(ollama("why is the sky blue"))
|
||||
from langchain_community.llms import Ollama
|
||||
ollama = Ollama(
|
||||
base_url='http://localhost:11434',
|
||||
model="llama3"
|
||||
)
|
||||
print(ollama.invoke("why is the sky blue"))
|
||||
```
|
||||
|
||||
Notice that we are defining the model and the base URL for Ollama.
|
||||
|
||||
108
docs/windows.md
108
docs/windows.md
@@ -1,47 +1,61 @@
|
||||
# Ollama Windows Preview
|
||||
|
||||
Welcome to the Ollama Windows preview.
|
||||
|
||||
No more WSL required!
|
||||
|
||||
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
|
||||
After installing Ollama Windows Preview, Ollama will run in the background and
|
||||
the `ollama` command line is available in `cmd`, `powershell` or your favorite
|
||||
terminal application. As usual the Ollama [api](./api.md) will be served on
|
||||
`http://localhost:11434`.
|
||||
|
||||
As this is a preview release, you should expect a few bugs here and there. If
|
||||
you run into a problem you can reach out on
|
||||
[Discord](https://discord.gg/ollama), or file an
|
||||
[issue](https://github.com/ollama/ollama/issues).
|
||||
Logs will often be helpful in diagnosing the problem (see
|
||||
[Troubleshooting](#troubleshooting) below)
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Windows 10 or newer, Home or Pro
|
||||
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
|
||||
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
```powershell
|
||||
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
|
||||
a "view logs" menu item to the app, and increses logging for the GUI app and
|
||||
server.
|
||||
|
||||
Ollama on Windows stores files in a few different locations. You can view them in
|
||||
the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||
- *app.log* contains logs from the GUI application
|
||||
- *server.log* contains the server logs
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
# Ollama Windows Preview
|
||||
|
||||
Welcome to the Ollama Windows preview.
|
||||
|
||||
No more WSL required!
|
||||
|
||||
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
|
||||
After installing Ollama Windows Preview, Ollama will run in the background and
|
||||
the `ollama` command line is available in `cmd`, `powershell` or your favorite
|
||||
terminal application. As usual the Ollama [api](./api.md) will be served on
|
||||
`http://localhost:11434`.
|
||||
|
||||
As this is a preview release, you should expect a few bugs here and there. If
|
||||
you run into a problem you can reach out on
|
||||
[Discord](https://discord.gg/ollama), or file an
|
||||
[issue](https://github.com/ollama/ollama/issues).
|
||||
Logs will often be helpful in diagnosing the problem (see
|
||||
[Troubleshooting](#troubleshooting) below)
|
||||
|
||||
## System Requirements
|
||||
|
||||
* Windows 10 or newer, Home or Pro
|
||||
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
|
||||
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
```powershell
|
||||
(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
|
||||
a "view logs" menu item to the app, and increses logging for the GUI app and
|
||||
server.
|
||||
|
||||
Ollama on Windows stores files in a few different locations. You can view them in
|
||||
the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||
- *app.log* contains logs from the GUI application
|
||||
- *server.log* contains the server logs
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
|
||||
|
||||
## Standalone CLI
|
||||
|
||||
The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
|
||||
installer. It installs in your account without requiring Administrator rights.
|
||||
We update Ollama regularly to support the latest models, and this installer will
|
||||
help you keep up to date.
|
||||
|
||||
If you'd like to install or integrate Ollama as a service, a standalone
|
||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||
and GPU library dependencies for Nvidia and AMD. This allows for embedding
|
||||
Ollama in existing applications, or running it as a system service via `ollama
|
||||
serve` with tools such as [NSSM](https://nssm.cc/).
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
|
||||
|
||||
`ollama run llama2 < sourcequestions.txt`
|
||||
`ollama run llama3 < sourcequestions.txt`
|
||||
|
||||
This concept is used in the following example.
|
||||
|
||||
|
||||
1
examples/flyio/.gitignore
vendored
Normal file
1
examples/flyio/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
fly.toml
|
||||
67
examples/flyio/README.md
Normal file
67
examples/flyio/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# Deploy Ollama to Fly.io
|
||||
|
||||
> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Ollama: https://ollama.com/download
|
||||
- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
|
||||
|
||||
## Steps
|
||||
|
||||
1. Login to Fly.io
|
||||
|
||||
```bash
|
||||
fly auth login
|
||||
```
|
||||
|
||||
1. Create a new Fly app
|
||||
|
||||
```bash
|
||||
fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
|
||||
```
|
||||
|
||||
1. Pull and run `orca-mini:3b`
|
||||
|
||||
```bash
|
||||
OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
|
||||
```
|
||||
|
||||
`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
|
||||
|
||||
## (Optional) Persistent Volume
|
||||
|
||||
By default Fly Machines use ephemeral storage which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
|
||||
|
||||
1. Create the Fly Volume
|
||||
|
||||
```bash
|
||||
fly volume create ollama
|
||||
```
|
||||
|
||||
1. Update `fly.toml` and add `[mounts]`
|
||||
|
||||
```toml
|
||||
[mounts]
|
||||
source = "ollama"
|
||||
destination = "/mnt/ollama/models"
|
||||
```
|
||||
|
||||
1. Update `fly.toml` and add `[env]`
|
||||
|
||||
```toml
|
||||
[env]
|
||||
OLLAMA_MODELS = "/mnt/ollama/models"
|
||||
```
|
||||
|
||||
1. Deploy your app
|
||||
|
||||
```bash
|
||||
fly deploy
|
||||
```
|
||||
|
||||
## (Optional) Hardware Acceleration
|
||||
|
||||
Fly.io GPU is currently in waitlist. Sign up for the waitlist: https://fly.io/gpu
|
||||
|
||||
Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.
|
||||
@@ -35,7 +35,7 @@ func main() {
|
||||
|
||||
ctx := context.Background()
|
||||
req := &api.ChatRequest{
|
||||
Model: "llama2",
|
||||
Model: "llama3",
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func main() {
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
|
||||
responseData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
||||
@@ -7,12 +7,24 @@
|
||||
|
||||
## Steps
|
||||
|
||||
1. Create the Ollama namespace, daemon set, and service
|
||||
1. Create the Ollama namespace, deployment, and service
|
||||
|
||||
```bash
|
||||
kubectl apply -f cpu.yaml
|
||||
```
|
||||
|
||||
## (Optional) Hardware Acceleration
|
||||
|
||||
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin) which is deployed in Kubernetes in form of daemonset. Follow the link for more details.
|
||||
|
||||
Once configured, create a GPU enabled Ollama deployment.
|
||||
|
||||
```bash
|
||||
kubectl apply -f gpu.yaml
|
||||
```
|
||||
|
||||
## Test
|
||||
|
||||
1. Port forward the Ollama service to connect and use it locally
|
||||
|
||||
```bash
|
||||
@@ -23,14 +35,4 @@
|
||||
|
||||
```bash
|
||||
ollama run orca-mini:3b
|
||||
```
|
||||
|
||||
## (Optional) Hardware Acceleration
|
||||
|
||||
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
|
||||
|
||||
Once configured, create a GPU enabled Ollama deployment.
|
||||
|
||||
```bash
|
||||
kubectl apply -f gpu.yaml
|
||||
```
|
||||
```
|
||||
@@ -40,9 +40,9 @@ while True:
|
||||
continue
|
||||
|
||||
# Prompt
|
||||
template = """Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
Use three sentences maximum and keep the answer as concise as possible.
|
||||
template = """Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
Use three sentences maximum and keep the answer as concise as possible.
|
||||
{context}
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
@@ -51,11 +51,11 @@ while True:
|
||||
template=template,
|
||||
)
|
||||
|
||||
llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
||||
llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
||||
qa_chain = RetrievalQA.from_chain_type(
|
||||
llm,
|
||||
retriever=vectorstore.as_retriever(),
|
||||
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
|
||||
)
|
||||
|
||||
result = qa_chain({"query": query})
|
||||
result = qa_chain({"query": query})
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from langchain.llms import Ollama
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain_community.llms import Ollama
|
||||
from langchain_community.document_loaders import WebBaseLoader
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
|
||||
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
|
||||
docs = loader.load()
|
||||
|
||||
llm = Ollama(model="llama2")
|
||||
llm = Ollama(model="llama3")
|
||||
chain = load_summarize_chain(llm, chain_type="stuff")
|
||||
|
||||
result = chain.run(docs)
|
||||
result = chain.invoke(docs)
|
||||
print(result)
|
||||
|
||||
@@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
|
||||
|
||||
## Running the Example
|
||||
|
||||
1. Ensure you have the `llama2` model installed:
|
||||
1. Ensure you have the `llama3` model installed:
|
||||
|
||||
```bash
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
2. Install the Python Requirements.
|
||||
@@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from langchain.llms import Ollama
|
||||
|
||||
input = input("What is your question?")
|
||||
llm = Ollama(model="llama2")
|
||||
llm = Ollama(model="llama3")
|
||||
res = llm.predict(input)
|
||||
print (res)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
PARAMETER temperature 1
|
||||
SYSTEM """
|
||||
You are Mario from super mario bros, acting as an assistant.
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
# Example character: Mario
|
||||
|
||||
This example shows how to create a basic character using Llama2 as the base model.
|
||||
This example shows how to create a basic character using Llama3 as the base model.
|
||||
|
||||
To run this example:
|
||||
|
||||
1. Download the Modelfile
|
||||
2. `ollama pull llama2` to get the base model used in the model file.
|
||||
2. `ollama pull llama3` to get the base model used in the model file.
|
||||
3. `ollama create NAME -f ./Modelfile`
|
||||
4. `ollama run NAME`
|
||||
|
||||
@@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
|
||||
What the model file looks like:
|
||||
|
||||
```
|
||||
FROM llama2
|
||||
FROM llama3
|
||||
PARAMETER temperature 1
|
||||
SYSTEM """
|
||||
You are Mario from Super Mario Bros, acting as an assistant.
|
||||
|
||||
@@ -2,16 +2,16 @@ import requests
|
||||
import json
|
||||
import random
|
||||
|
||||
model = "llama2"
|
||||
model = "llama3"
|
||||
template = {
|
||||
"firstName": "",
|
||||
"lastName": "",
|
||||
"firstName": "",
|
||||
"lastName": "",
|
||||
"address": {
|
||||
"street": "",
|
||||
"city": "",
|
||||
"state": "",
|
||||
"street": "",
|
||||
"city": "",
|
||||
"state": "",
|
||||
"zipCode": ""
|
||||
},
|
||||
},
|
||||
"phoneNumber": ""
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ countries = [
|
||||
"France",
|
||||
]
|
||||
country = random.choice(countries)
|
||||
model = "llama2"
|
||||
model = "llama3"
|
||||
|
||||
prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
|
||||
|
||||
## Running the Example
|
||||
|
||||
1. Ensure you have the `llama2` model installed:
|
||||
1. Ensure you have the `llama3` model installed:
|
||||
|
||||
```bash
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
2. Install the Python Requirements.
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import requests
|
||||
|
||||
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
|
||||
model = "llama2" # TODO: update this for whatever model you wish to use
|
||||
model = "llama3" # TODO: update this for whatever model you wish to use
|
||||
|
||||
|
||||
def chat(messages):
|
||||
|
||||
@@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
|
||||
|
||||
## Running the Example
|
||||
|
||||
1. Ensure you have the `llama2` model installed:
|
||||
1. Ensure you have the `llama3` model installed:
|
||||
|
||||
```bash
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
2. Install the Python Requirements.
|
||||
|
||||
@@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
|
||||
|
||||
## Usage
|
||||
|
||||
1. Add llama2 to have the mentors ask your questions:
|
||||
1. Add llama3 to have the mentors ask your questions:
|
||||
|
||||
```bash
|
||||
ollama pull llama2
|
||||
ollama pull llama3
|
||||
```
|
||||
|
||||
2. Install prerequisites:
|
||||
|
||||
@@ -15,7 +15,7 @@ async function characterGenerator() {
|
||||
ollama.setModel("stablebeluga2:70b-q4_K_M");
|
||||
const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
|
||||
|
||||
const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
|
||||
const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
|
||||
|
||||
fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
|
||||
if (err) throw err;
|
||||
@@ -23,4 +23,4 @@ async function characterGenerator() {
|
||||
});
|
||||
}
|
||||
|
||||
characterGenerator();
|
||||
characterGenerator();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import * as readline from "readline";
|
||||
|
||||
const model = "llama2";
|
||||
const model = "llama3";
|
||||
type Message = {
|
||||
role: "assistant" | "user" | "system";
|
||||
content: string;
|
||||
@@ -74,4 +74,4 @@ async function main() {
|
||||
|
||||
}
|
||||
|
||||
main();
|
||||
main();
|
||||
|
||||
@@ -81,8 +81,10 @@ func commonAMDValidateLibDir() (string, error) {
|
||||
}
|
||||
|
||||
// Well known location(s)
|
||||
if rocmLibUsable(RocmStandardLocation) {
|
||||
return RocmStandardLocation, nil
|
||||
for _, path := range RocmStandardLocations {
|
||||
if rocmLibUsable(path) {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Installer payload location if we're running the installed binary
|
||||
|
||||
@@ -25,12 +25,12 @@ const (
|
||||
// Prefix with the node dir
|
||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
||||
RocmStandardLocation = "/opt/rocm/lib"
|
||||
)
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
|
||||
)
|
||||
|
||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
|
||||
|
||||
// TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
|
||||
iGPUName = "AMD Radeon(TM) Graphics"
|
||||
@@ -22,7 +21,8 @@ const (
|
||||
|
||||
var (
|
||||
// Used to validate if the given ROCm lib is usable
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
||||
)
|
||||
|
||||
func AMDGetGPUInfo() []GpuInfo {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -24,45 +26,8 @@ func PayloadsDir() (string, error) {
|
||||
defer lock.Unlock()
|
||||
var err error
|
||||
if payloadsDir == "" {
|
||||
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
|
||||
// On Windows we do not carry the payloads inside the main executable
|
||||
if runtime.GOOS == "windows" && runnersDir == "" {
|
||||
appExe, err := os.Executable()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup executable path", "error", err)
|
||||
return "", err
|
||||
}
|
||||
runnersDir := envconfig.RunnersDir
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup working directory", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
var paths []string
|
||||
for _, root := range []string{appExe, cwd} {
|
||||
paths = append(paths,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
||||
)
|
||||
}
|
||||
|
||||
// Try a few variations to improve developer experience when building from source in the local tree
|
||||
for _, p := range paths {
|
||||
candidate := filepath.Join(p, "ollama_runners")
|
||||
_, err := os.Stat(candidate)
|
||||
if err == nil {
|
||||
runnersDir = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
if runnersDir == "" {
|
||||
err = fmt.Errorf("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
|
||||
slog.Error("incomplete distribution", "error", err)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if runnersDir != "" {
|
||||
payloadsDir = runnersDir
|
||||
return payloadsDir, nil
|
||||
@@ -70,7 +35,7 @@ func PayloadsDir() (string, error) {
|
||||
|
||||
// The remainder only applies on non-windows where we still carry payloads in the main executable
|
||||
cleanupTmpDirs()
|
||||
tmpDir := os.Getenv("OLLAMA_TMPDIR")
|
||||
tmpDir := envconfig.TmpDir
|
||||
if tmpDir == "" {
|
||||
tmpDir, err = os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
@@ -133,7 +98,7 @@ func cleanupTmpDirs() {
|
||||
func Cleanup() {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
runnersDir := os.Getenv("OLLAMA_RUNNERS_DIR")
|
||||
runnersDir := envconfig.RunnersDir
|
||||
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
|
||||
// We want to fully clean up the tmpdir parent of the payloads dir
|
||||
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
|
||||
|
||||
85
gpu/gpu.go
85
gpu/gpu.go
@@ -21,16 +21,18 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
)
|
||||
|
||||
type handles struct {
|
||||
deviceCount int
|
||||
cudart *C.cudart_handle_t
|
||||
nvcuda *C.nvcuda_handle_t
|
||||
}
|
||||
|
||||
const (
|
||||
cudaMinimumMemory = 457 * format.MebiByte
|
||||
rocmMinimumMemory = 457 * format.MebiByte
|
||||
cudaMinimumMemory = 256 * format.MebiByte
|
||||
rocmMinimumMemory = 256 * format.MebiByte
|
||||
)
|
||||
|
||||
var gpuMutex sync.Mutex
|
||||
@@ -62,6 +64,22 @@ var CudartWindowsGlobs = []string{
|
||||
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
||||
}
|
||||
|
||||
var NvcudaLinuxGlobs = []string{
|
||||
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/libcuda.so*",
|
||||
"/usr/lib/wsl/lib/libcuda.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcuda.so*",
|
||||
"/opt/cuda/lib*/libcuda.so*",
|
||||
"/usr/local/cuda/lib*/libcuda.so*",
|
||||
"/usr/lib*/libcuda.so*",
|
||||
"/usr/local/lib*/libcuda.so*",
|
||||
}
|
||||
|
||||
var NvcudaWindowsGlobs = []string{
|
||||
"c:\\windows\\system*\\nvcuda.dll",
|
||||
}
|
||||
|
||||
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
@@ -74,6 +92,8 @@ func initGPUHandles() *handles {
|
||||
gpuHandles := &handles{}
|
||||
var cudartMgmtName string
|
||||
var cudartMgmtPatterns []string
|
||||
var nvcudaMgmtName string
|
||||
var nvcudaMgmtPatterns []string
|
||||
|
||||
tmpDir, _ := PayloadsDir()
|
||||
switch runtime.GOOS {
|
||||
@@ -82,6 +102,9 @@ func initGPUHandles() *handles {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtName = "nvcuda.dll"
|
||||
nvcudaMgmtPatterns = NvcudaWindowsGlobs
|
||||
case "linux":
|
||||
cudartMgmtName = "libcudart.so*"
|
||||
if tmpDir != "" {
|
||||
@@ -89,11 +112,25 @@ func initGPUHandles() *handles {
|
||||
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
|
||||
}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtName = "libcuda.so*"
|
||||
nvcudaMgmtPatterns = NvcudaLinuxGlobs
|
||||
default:
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
slog.Info("Detecting GPUs")
|
||||
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
|
||||
if len(nvcudaLibPaths) > 0 {
|
||||
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
|
||||
if nvcuda != nil {
|
||||
slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
|
||||
gpuHandles.nvcuda = nvcuda
|
||||
gpuHandles.deviceCount = deviceCount
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
||||
if len(cudartLibPaths) > 0 {
|
||||
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
|
||||
@@ -118,6 +155,9 @@ func GetGPUInfo() GpuInfoList {
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_release(*gpuHandles.cudart)
|
||||
}
|
||||
if gpuHandles.nvcuda != nil {
|
||||
C.nvcuda_release(*gpuHandles.nvcuda)
|
||||
}
|
||||
}()
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
@@ -126,6 +166,12 @@ func GetGPUInfo() GpuInfoList {
|
||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||
}
|
||||
|
||||
// On windows we bundle the nvidia library one level above the runner dir
|
||||
depPath := ""
|
||||
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||
depPath = filepath.Dir(envconfig.RunnersDir)
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
resp := []GpuInfo{}
|
||||
|
||||
@@ -138,7 +184,11 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "cuda",
|
||||
}
|
||||
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||
} else {
|
||||
C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
@@ -154,6 +204,7 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo.Major = int(memInfo.major)
|
||||
gpuInfo.Minor = int(memInfo.minor)
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DependencyPath = depPath
|
||||
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
resp = append(resp, gpuInfo)
|
||||
@@ -196,9 +247,10 @@ func GetCPUMem() (memInfo, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||
var ldPaths []string
|
||||
var patterns []string
|
||||
gpuLibPaths := []string{}
|
||||
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||
|
||||
@@ -218,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||
}
|
||||
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
||||
}
|
||||
patterns = append(patterns, defaultPatterns...)
|
||||
slog.Debug("gpu library search", "globs", patterns)
|
||||
for _, pattern := range patterns {
|
||||
|
||||
// Nvidia PhysX known to return bogus results
|
||||
if strings.Contains(pattern, "PhysX") {
|
||||
slog.Debug("skipping PhysX cuda library path", "path", pattern)
|
||||
}
|
||||
// Ignore glob discovery errors
|
||||
matches, _ := filepath.Glob(pattern)
|
||||
for _, match := range matches {
|
||||
@@ -267,8 +325,25 @@ func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
|
||||
return 0, nil, ""
|
||||
}
|
||||
|
||||
func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
|
||||
var resp C.nvcuda_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range nvcudaLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.nvcuda_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return int(resp.num_devices), &resp.ch, libPath
|
||||
}
|
||||
}
|
||||
return 0, nil, ""
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
if envconfig.Debug {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
|
||||
@@ -10,6 +10,12 @@ package gpu
|
||||
import "C"
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
const (
|
||||
metalMinimumMemory = 384 * format.MebiByte
|
||||
)
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
@@ -32,7 +38,7 @@ func GetGPUInfo() GpuInfoList {
|
||||
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||
info.FreeMemory = info.TotalMemory
|
||||
|
||||
info.MinimumMemory = 0
|
||||
info.MinimumMemory = metalMinimumMemory
|
||||
return []GpuInfo{info}
|
||||
}
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ void cpu_check_ram(mem_info_t *resp);
|
||||
#endif
|
||||
|
||||
#include "gpu_info_cudart.h"
|
||||
#include "gpu_info_nvcuda.h"
|
||||
|
||||
#endif // __GPU_INFO_H__
|
||||
#endif // __APPLE__
|
||||
@@ -6,9 +6,9 @@
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum cudartReturn_enum {
|
||||
CUDART_SUCCESS = 0,
|
||||
CUDA_ERROR_INVALID_VALUE = 1,
|
||||
CUDA_ERROR_MEMORY_ALLOCATION = 2,
|
||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
CUDART_ERROR_INVALID_VALUE = 1,
|
||||
CUDART_ERROR_MEMORY_ALLOCATION = 2,
|
||||
CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
// Other values omitted for now...
|
||||
} cudartReturn_t;
|
||||
|
||||
|
||||
203
gpu/gpu_info_nvcuda.c
Normal file
203
gpu/gpu_info_nvcuda.c
Normal file
@@ -0,0 +1,203 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include "gpu_info_nvcuda.h"
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
||||
CUresult ret;
|
||||
resp->err = NULL;
|
||||
resp->num_devices = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
|
||||
{"cuInit", (void *)&resp->ch.cuInit},
|
||||
{"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
|
||||
{"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
|
||||
{"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
|
||||
{"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
|
||||
{"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
|
||||
{"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
|
||||
{"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
|
||||
{"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
nvcuda_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!*l[i].p) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cuInit)(0);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
snprintf(buf, buflen, "nvcuda init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
int version = 0;
|
||||
nvcudaDriverVersion_t driverVersion;
|
||||
driverVersion.major = 0;
|
||||
driverVersion.minor = 0;
|
||||
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.cuDriverGetVersion)(&version);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
|
||||
} else {
|
||||
driverVersion.major = version / 1000;
|
||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const int buflen = 256;
|
||||
void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
nvcudaMemory_t memInfo = {0,0};
|
||||
CUresult ret;
|
||||
CUdevice device = -1;
|
||||
CUcontext ctx = NULL;
|
||||
char buf[buflen + 1];
|
||||
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
|
||||
if (h.handle == NULL) {
|
||||
resp->err = strdup("nvcuda handle isn't initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuDeviceGet)(&device, i);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "nvcuda device failed to initialize");
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->major = 0;
|
||||
resp->minor = 0;
|
||||
int major = 0;
|
||||
int minor = 0;
|
||||
ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
|
||||
} else {
|
||||
ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
|
||||
} else {
|
||||
resp->minor = minor;
|
||||
resp->major = major;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*h.cuDeviceGetUuid)(&uuid, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||
} else {
|
||||
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
uuid.bytes[0],
|
||||
uuid.bytes[1],
|
||||
uuid.bytes[2],
|
||||
uuid.bytes[3],
|
||||
uuid.bytes[4],
|
||||
uuid.bytes[5],
|
||||
uuid.bytes[6],
|
||||
uuid.bytes[7],
|
||||
uuid.bytes[8],
|
||||
uuid.bytes[9],
|
||||
uuid.bytes[10],
|
||||
uuid.bytes[11],
|
||||
uuid.bytes[12],
|
||||
uuid.bytes[13],
|
||||
uuid.bytes[14],
|
||||
uuid.bytes[15]
|
||||
);
|
||||
}
|
||||
|
||||
// To get memory we have to set (and release) a context
|
||||
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
// Best effort on failure...
|
||||
(*h.cuCtxDestroy)(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
|
||||
|
||||
|
||||
ret = (*h.cuCtxDestroy)(ctx);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda failed to release primary device context %d", ret);
|
||||
}
|
||||
}
|
||||
|
||||
void nvcuda_release(nvcuda_handle_t h) {
|
||||
LOG(h.verbose, "releasing nvcuda library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
// TODO and other context release logic?
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
71
gpu/gpu_info_nvcuda.h
Normal file
71
gpu/gpu_info_nvcuda.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_NVCUDA_H__
|
||||
#define __GPU_INFO_NVCUDA_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum cudaError_enum {
|
||||
CUDA_SUCCESS = 0,
|
||||
CUDA_ERROR_INVALID_VALUE = 1,
|
||||
CUDA_ERROR_MEMORY_ALLOCATION = 2,
|
||||
CUDA_ERROR_NOT_INITIALIZED = 3,
|
||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
// Other values omitted for now...
|
||||
} CUresult;
|
||||
|
||||
typedef enum CUdevice_attribute_enum {
|
||||
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
|
||||
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
|
||||
|
||||
// TODO - not yet wired up but may be useful for Jetson or other
|
||||
// integrated GPU scenarios with shared memory
|
||||
CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
|
||||
|
||||
} CUdevice_attribute;
|
||||
|
||||
typedef void *nvcudaDevice_t; // Opaque is sufficient
|
||||
typedef struct nvcudaMemory_st {
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
} nvcudaMemory_t;
|
||||
|
||||
typedef struct nvcudaDriverVersion {
|
||||
int major;
|
||||
int minor;
|
||||
} nvcudaDriverVersion_t;
|
||||
|
||||
typedef struct CUuuid_st {
|
||||
unsigned char bytes[16];
|
||||
} CUuuid;
|
||||
|
||||
typedef int CUdevice;
|
||||
typedef void* CUcontext;
|
||||
|
||||
typedef struct nvcuda_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
CUresult (*cuInit)(unsigned int Flags);
|
||||
CUresult (*cuDriverGetVersion)(int *driverVersion);
|
||||
CUresult (*cuDeviceGetCount)(int *);
|
||||
CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
|
||||
CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
|
||||
|
||||
// Context specific aspects
|
||||
CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
|
||||
CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
|
||||
CUresult (*cuCtxDestroy)(CUcontext ctx);
|
||||
} nvcuda_handle_t;
|
||||
|
||||
typedef struct nvcuda_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
nvcuda_handle_t ch;
|
||||
int num_devices;
|
||||
} nvcuda_init_resp_t;
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
|
||||
void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void nvcuda_release(nvcuda_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVCUDA_H__
|
||||
#endif // __APPLE__
|
||||
117
integration/max_queue_test.go
Normal file
117
integration/max_queue_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
|
||||
// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
|
||||
threadCount := 32
|
||||
mq := os.Getenv("OLLAMA_MAX_QUEUE")
|
||||
if mq != "" {
|
||||
var err error
|
||||
threadCount, err = strconv.Atoi(mq)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
|
||||
}
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}
|
||||
resp := []string{"explore", "discover", "ocean"}
|
||||
|
||||
// CPU mode takes much longer at the limit with a large queue setting
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
embedCtx := ctx
|
||||
|
||||
var genwg sync.WaitGroup
|
||||
go func() {
|
||||
genwg.Add(1)
|
||||
defer genwg.Done()
|
||||
slog.Info("Starting generate request")
|
||||
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
|
||||
slog.Info("generate completed")
|
||||
}()
|
||||
|
||||
// Give the generate a chance to get started before we start hammering on embed requests
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
|
||||
busyCount := 0
|
||||
resetByPeerCount := 0
|
||||
canceledCount := 0
|
||||
succesCount := 0
|
||||
counterMu := sync.Mutex{}
|
||||
var embedwg sync.WaitGroup
|
||||
for i := 0; i < threadCount; i++ {
|
||||
go func(i int) {
|
||||
embedwg.Add(1)
|
||||
defer embedwg.Done()
|
||||
slog.Info("embed started", "id", i)
|
||||
embedReq := api.EmbeddingRequest{
|
||||
Model: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
Options: req.Options,
|
||||
}
|
||||
// Fresh client for every request
|
||||
client, _ = GetTestEndpoint()
|
||||
|
||||
resp, genErr := client.Embeddings(embedCtx, &embedReq)
|
||||
counterMu.Lock()
|
||||
defer counterMu.Unlock()
|
||||
switch {
|
||||
case genErr == nil:
|
||||
succesCount++
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
busyCount++
|
||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
}(i)
|
||||
}
|
||||
genwg.Wait()
|
||||
slog.Info("generate done, waiting for embeds")
|
||||
embedwg.Wait()
|
||||
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
}
|
||||
17
llm/ext_server/server.cpp
vendored
17
llm/ext_server/server.cpp
vendored
@@ -1032,7 +1032,7 @@ struct llama_server_context
|
||||
slot.has_next_token = false;
|
||||
}
|
||||
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
|
||||
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
|
||||
{
|
||||
slot.stopped_eos = true;
|
||||
slot.has_next_token = false;
|
||||
@@ -1144,12 +1144,15 @@ struct llama_server_context
|
||||
|
||||
res.result_json = json
|
||||
{
|
||||
{"content", tkn.text_to_send},
|
||||
{"stop", false},
|
||||
{"slot_id", slot.id},
|
||||
{"multimodal", multimodal}
|
||||
};
|
||||
|
||||
if (!llama_token_is_eog(model, tkn.tok)) {
|
||||
res.result_json["content"] = tkn.text_to_send;
|
||||
}
|
||||
|
||||
if (slot.sparams.n_probs > 0)
|
||||
{
|
||||
std::vector<completion_token_output> probs_output = {};
|
||||
@@ -1183,8 +1186,6 @@ struct llama_server_context
|
||||
{"model", params.model_alias},
|
||||
{"tokens_predicted", slot.n_decoded},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
{"generation_settings", get_formated_generation(slot)},
|
||||
{"prompt", slot.prompt},
|
||||
{"truncated", slot.truncated},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
{"stopped_word", slot.stopped_word},
|
||||
@@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||
if (strncmp(sep, "int:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
||||
kvo.int_value = std::atol(sep);
|
||||
kvo.val_i64 = std::atol(sep);
|
||||
} else if (strncmp(sep, "float:", 6) == 0) {
|
||||
sep += 6;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
||||
kvo.float_value = std::atof(sep);
|
||||
kvo.val_f64 = std::atof(sep);
|
||||
} else if (strncmp(sep, "bool:", 5) == 0) {
|
||||
sep += 5;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
||||
if (std::strcmp(sep, "true") == 0) {
|
||||
kvo.bool_value = true;
|
||||
kvo.val_bool = true;
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.bool_value = false;
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
||||
invalid_param = true;
|
||||
|
||||
@@ -42,7 +42,7 @@ function init_vars {
|
||||
"-DLLAMA_NATIVE=off"
|
||||
)
|
||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||
$script:ARCH = "amd64" # arm not yet supported.
|
||||
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
||||
md "$script:DIST_BASE" -ea 0 > $null
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
@@ -213,11 +213,11 @@ function build_static() {
|
||||
}
|
||||
}
|
||||
|
||||
function build_cpu() {
|
||||
function build_cpu($gen_arch) {
|
||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||
# remaining llama.cpp builds use MSVC
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||
$script:distDir="$script:DIST_BASE\cpu"
|
||||
write-host "Building LCD CPU"
|
||||
@@ -349,11 +349,15 @@ if ($($args.count) -eq 0) {
|
||||
git_module_setup
|
||||
apply_patches
|
||||
build_static
|
||||
build_cpu
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
if ($script:ARCH -eq "arm64") {
|
||||
build_cpu("ARM64")
|
||||
} else { # amd64
|
||||
build_cpu("x64")
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_rocm
|
||||
}
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
|
||||
|
||||
Submodule llm/llama.cpp updated: 46e12c4692...952d03dbea
@@ -4,6 +4,7 @@ package llm
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
||||
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
|
||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
||||
// #include <stdlib.h>
|
||||
|
||||
@@ -3,12 +3,11 @@ package llm
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
)
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
@@ -50,15 +49,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
for _, info := range gpus {
|
||||
memoryAvailable += info.FreeMemory
|
||||
}
|
||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if userLimit != "" {
|
||||
avail, err := strconv.ParseUint(userLimit, 10, 64)
|
||||
if err != nil {
|
||||
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
|
||||
} else {
|
||||
slog.Info("user override memory limit", "OLLAMA_MAX_VRAM", avail, "actual", memoryAvailable)
|
||||
memoryAvailable = avail
|
||||
}
|
||||
if envconfig.MaxVRAM > 0 {
|
||||
memoryAvailable = envconfig.MaxVRAM
|
||||
}
|
||||
|
||||
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
|
||||
@@ -88,19 +80,24 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
graphFullOffload *= uint64(len(gpus))
|
||||
graphPartialOffload *= uint64(len(gpus))
|
||||
|
||||
// on metal there's no partial offload overhead
|
||||
if gpus[0].Library == "metal" {
|
||||
graphPartialOffload = graphFullOffload
|
||||
}
|
||||
|
||||
layers := ggml.Tensors().Layers()
|
||||
|
||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
||||
memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
|
||||
|
||||
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
|
||||
|
||||
if memoryRequiredPartial > memoryAvailable {
|
||||
slog.Debug("insufficient VRAM to load any model layers")
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
layers := ggml.Tensors().Layers()
|
||||
|
||||
var memoryLayerOutput uint64
|
||||
if layer, ok := layers["output_norm"]; ok {
|
||||
memoryLayerOutput += layer.size()
|
||||
|
||||
24
llm/patches/05-clip-fix.diff
Normal file
24
llm/patches/05-clip-fix.diff
Normal file
@@ -0,0 +1,24 @@
|
||||
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||
index e3c9bcd4..b43f892d 100644
|
||||
--- a/examples/llava/clip.cpp
|
||||
+++ b/examples/llava/clip.cpp
|
||||
@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
struct ggml_tensor * embeddings = inp;
|
||||
if (ctx->has_class_embedding) {
|
||||
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
||||
+ }
|
||||
+ ggml_set_name(embeddings, "embeddings");
|
||||
+ ggml_set_input(embeddings);
|
||||
+
|
||||
+ if (ctx->has_class_embedding) {
|
||||
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
||||
embeddings = ggml_acc(ctx0, embeddings, inp,
|
||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
||||
}
|
||||
- ggml_set_name(embeddings, "embeddings");
|
||||
- ggml_set_input(embeddings);
|
||||
-
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||
ggml_set_name(positions, "positions");
|
||||
315
llm/server.go
315
llm/server.go
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
)
|
||||
|
||||
type LlamaServer interface {
|
||||
@@ -73,8 +74,7 @@ func LoadModel(model string) (*GGML, error) {
|
||||
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||
var err error
|
||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
@@ -125,7 +125,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
} else {
|
||||
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
|
||||
}
|
||||
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
|
||||
demandLib := envconfig.LLMLibrary
|
||||
if demandLib != "" {
|
||||
serverPath := availableServers[demandLib]
|
||||
if serverPath == "" {
|
||||
@@ -146,7 +146,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
||||
"--embedding",
|
||||
}
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
if envconfig.Debug {
|
||||
params = append(params, "--log-format", "json")
|
||||
} else {
|
||||
params = append(params, "--log-disable")
|
||||
@@ -156,7 +156,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
||||
}
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
if envconfig.Debug {
|
||||
params = append(params, "--verbose")
|
||||
}
|
||||
|
||||
@@ -194,16 +194,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
params = append(params, "--numa")
|
||||
}
|
||||
|
||||
// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
|
||||
numParallel := 1
|
||||
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||
numParallel, err = strconv.Atoi(onp)
|
||||
if err != nil || numParallel <= 0 {
|
||||
err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
|
||||
slog.Error("misconfiguration", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
numParallel := envconfig.NumParallel
|
||||
|
||||
// TODO (jmorganca): multimodal models don't support parallel yet
|
||||
// see https://github.com/ollama/ollama/issues/4165
|
||||
if len(projectors) > 0 {
|
||||
numParallel = 1
|
||||
slog.Warn("multimodal models don't support parallel requests yet")
|
||||
}
|
||||
|
||||
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||
|
||||
for i := 0; i < len(servers); i++ {
|
||||
@@ -234,13 +233,13 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
if runtime.GOOS == "windows" {
|
||||
pathEnv = "PATH"
|
||||
}
|
||||
// append the server directory to LD_LIBRARY_PATH/PATH
|
||||
// prepend the server directory to LD_LIBRARY_PATH/PATH
|
||||
libraryPaths := []string{dir}
|
||||
|
||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||
// Append our runner directory to the path
|
||||
// This will favor system libraries over our bundled library dependencies
|
||||
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
|
||||
}
|
||||
|
||||
// Note: we always put the dependency path first
|
||||
@@ -276,15 +275,31 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
}
|
||||
|
||||
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
||||
s.cmd.Env = append(os.Environ(), libEnv)
|
||||
s.cmd.Env = os.Environ()
|
||||
s.cmd.Stdout = os.Stdout
|
||||
s.cmd.Stderr = s.status
|
||||
|
||||
// TODO - multiple GPU selection logic...
|
||||
key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
||||
if key != "" {
|
||||
s.cmd.Env = append(s.cmd.Env, key+"="+val)
|
||||
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add the path and visible devices variable with our adjusted version
|
||||
pathNeeded := true
|
||||
devicesNeeded := visibleDevicesEnv != ""
|
||||
for i := range s.cmd.Env {
|
||||
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
|
||||
s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
|
||||
devicesNeeded = false
|
||||
}
|
||||
}
|
||||
if pathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
if devicesNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||
}
|
||||
|
||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||
@@ -301,19 +316,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
continue
|
||||
}
|
||||
|
||||
// reap subprocess when it exits
|
||||
go func() {
|
||||
// Exit status managed via getServerStatus
|
||||
_ = s.cmd.Wait()
|
||||
}()
|
||||
|
||||
// TODO - make sure this is all wired up correctly
|
||||
// if err = s.WaitUntilRunning(); err != nil {
|
||||
// slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||
// s.Close()
|
||||
// finalErr = err
|
||||
// continue
|
||||
// }
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -345,7 +347,7 @@ type ServerStatus int
|
||||
|
||||
const ( // iota is reset to 0
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusNoSlotsAvaialble
|
||||
ServerStatusNoSlotsAvailable
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusNotResponding
|
||||
ServerStatusError
|
||||
@@ -355,7 +357,7 @@ func (s ServerStatus) ToString() string {
|
||||
switch s {
|
||||
case ServerStatusReady:
|
||||
return "llm server ready"
|
||||
case ServerStatusNoSlotsAvaialble:
|
||||
case ServerStatusNoSlotsAvailable:
|
||||
return "llm busy - no slots available"
|
||||
case ServerStatusLoadingModel:
|
||||
return "llm server loading model"
|
||||
@@ -412,7 +414,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
case "ok":
|
||||
return ServerStatusReady, nil
|
||||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvaialble, nil
|
||||
return ServerStatusNoSlotsAvailable, nil
|
||||
case "loading model":
|
||||
return ServerStatusLoadingModel, nil
|
||||
default:
|
||||
@@ -420,6 +422,29 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
|
||||
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
|
||||
var retries int
|
||||
for {
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return status, err
|
||||
}
|
||||
|
||||
if status == ServerStatusNoSlotsAvailable {
|
||||
if retries >= 10 {
|
||||
return status, fmt.Errorf("no slots available after %d retries", retries)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
retries++
|
||||
continue
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *llmServer) Ping(ctx context.Context) error {
|
||||
_, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
@@ -517,7 +542,6 @@ ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxRetries = 3
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
@@ -593,7 +617,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != ServerStatusReady {
|
||||
@@ -607,133 +631,113 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
}
|
||||
|
||||
retryDelay := 100 * time.Microsecond
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
if retries > 0 {
|
||||
time.Sleep(retryDelay) // wait before retrying
|
||||
retryDelay *= 2 // exponential backoff
|
||||
}
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
serverReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
scanner := bufio.NewScanner(res.Body)
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||
var lastToken string
|
||||
var tokenRepeat int
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
retryNeeded := false
|
||||
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||
var lastToken string
|
||||
var tokenRepeat int
|
||||
var c completion
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// try again on slot unavailable
|
||||
if bytes.Contains(line, []byte("slot unavailable")) {
|
||||
retryNeeded = true
|
||||
break
|
||||
}
|
||||
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
var c completion
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
tokenRepeat = 0
|
||||
}
|
||||
|
||||
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||
if tokenRepeat > 30 {
|
||||
slog.Debug("prediction aborted, token repeat limit reached")
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
tokenRepeat = 0
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
s.Close()
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
|
||||
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
||||
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||
if tokenRepeat > 30 {
|
||||
slog.Debug("prediction aborted, token repeat limit reached")
|
||||
return ctx.Err()
|
||||
}
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
if !retryNeeded {
|
||||
return nil // success
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here ideally
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
s.Close()
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
||||
}
|
||||
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
@@ -750,8 +754,9 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
||||
return nil, err
|
||||
}
|
||||
defer s.sem.Release(1)
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
@@ -806,7 +811,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
@@ -858,7 +863,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
@@ -900,7 +905,13 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
|
||||
func (s *llmServer) Close() error {
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
return s.cmd.Process.Kill()
|
||||
if err := s.cmd.Process.Kill(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = s.cmd.Wait()
|
||||
|
||||
slog.Debug("llama server stopped")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,7 +19,7 @@ export default function () {
|
||||
const [step, setStep] = useState<Step>(Step.WELCOME)
|
||||
const [commandCopied, setCommandCopied] = useState<boolean>(false)
|
||||
|
||||
const command = 'ollama run llama2'
|
||||
const command = 'ollama run llama3'
|
||||
|
||||
return (
|
||||
<div className='drag'>
|
||||
|
||||
132
parser/parser.go
132
parser/parser.go
@@ -1,132 +0,0 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type Command struct {
|
||||
Name string
|
||||
Args string
|
||||
}
|
||||
|
||||
func (c *Command) Reset() {
|
||||
c.Name = ""
|
||||
c.Args = ""
|
||||
}
|
||||
|
||||
func Parse(reader io.Reader) ([]Command, error) {
|
||||
var commands []Command
|
||||
var command, modelCommand Command
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
||||
scanner.Split(scanModelfile)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
fields := bytes.SplitN(line, []byte(" "), 2)
|
||||
if len(fields) == 0 || len(fields[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch string(bytes.ToUpper(fields[0])) {
|
||||
case "FROM":
|
||||
command.Name = "model"
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
// copy command for validation
|
||||
modelCommand = command
|
||||
case "ADAPTER":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
command.Args = string(fields[1])
|
||||
case "PARAMETER":
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("missing value for %s", fields)
|
||||
}
|
||||
|
||||
command.Name = string(fields[0])
|
||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
||||
case "EMBED":
|
||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
||||
case "MESSAGE":
|
||||
command.Name = string(bytes.ToLower(fields[0]))
|
||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
||||
if len(fields) < 2 {
|
||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
||||
}
|
||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
||||
default:
|
||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
||||
// log a warning for unknown commands
|
||||
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
commands = append(commands, command)
|
||||
command.Reset()
|
||||
}
|
||||
|
||||
if modelCommand.Args == "" {
|
||||
return nil, errors.New("no FROM line for the model was specified")
|
||||
}
|
||||
|
||||
return commands, scanner.Err()
|
||||
}
|
||||
|
||||
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
}
|
||||
|
||||
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if advance > 0 && token != nil {
|
||||
return advance, token, nil
|
||||
}
|
||||
|
||||
return bufio.ScanLines(data, atEOF)
|
||||
}
|
||||
|
||||
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
newline := bytes.IndexByte(data, '\n')
|
||||
|
||||
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
|
||||
end := bytes.Index(data[start+len(openBytes):], closeBytes)
|
||||
if end < 0 {
|
||||
if atEOF {
|
||||
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
|
||||
} else {
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
n := start + len(openBytes) + end + len(closeBytes)
|
||||
|
||||
newData := data[:start]
|
||||
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
|
||||
return n, newData, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Parser(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM model1
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "model1"},
|
||||
{Name: "adapter", Args: "adapter1"},
|
||||
{Name: "license", Args: "MIT"},
|
||||
{Name: "param1", Args: "value1"},
|
||||
{Name: "param2", Args: "value2"},
|
||||
{Name: "template", Args: "template1"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_NoFromLine(t *testing.T) {
|
||||
|
||||
input := `
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "no FROM line")
|
||||
}
|
||||
|
||||
func Test_Parser_MissingValue(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
PARAMETER param1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
||||
|
||||
}
|
||||
|
||||
func Test_Parser_Messages(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE system You are a Parser. Always Parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
commands, err := Parse(reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, commands)
|
||||
}
|
||||
|
||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
||||
|
||||
input := `
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
_, err := Parse(reader)
|
||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
}
|
||||
@@ -7,6 +7,8 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function checkEnv() {
|
||||
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||
Write-host "Building for ${script:TARGET_ARCH}"
|
||||
write-host "Locating required tools and paths"
|
||||
$script:SRC_DIR=$PWD
|
||||
if (!$env:VCToolsRedistDir) {
|
||||
@@ -30,7 +32,7 @@ function checkEnv() {
|
||||
|
||||
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
|
||||
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
|
||||
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
|
||||
$env:CGO_ENABLED="1"
|
||||
echo "Checking version"
|
||||
if (!$env:VERSION) {
|
||||
@@ -81,8 +83,8 @@ function buildOllama() {
|
||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
|
||||
cp .\ollama.exe .\dist\windows-amd64\
|
||||
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
|
||||
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
|
||||
}
|
||||
|
||||
function buildApp() {
|
||||
@@ -127,16 +129,16 @@ function buildInstaller() {
|
||||
cd "${script:SRC_DIR}\app"
|
||||
$env:PKG_VERSION=$script:PKG_VERSION
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
|
||||
} else {
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
|
||||
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
|
||||
}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function distZip() {
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
|
||||
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
|
||||
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -166,8 +166,8 @@ fi
|
||||
|
||||
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
|
||||
# Look for pre-existing ROCm v6 before downloading the dependencies
|
||||
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm"; do
|
||||
if [ -n "${search}" ] && [ -e "${search}/lib/libhipblas.so.2" ]; then
|
||||
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
|
||||
if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
|
||||
status "Compatible AMD GPU ROCm library detected at ${search}"
|
||||
install_success
|
||||
exit 0
|
||||
|
||||
174
server/envconfig/config.go
Normal file
174
server/envconfig/config.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// Set via OLLAMA_ORIGINS in the environment
|
||||
AllowOrigins []string
|
||||
// Set via OLLAMA_DEBUG in the environment
|
||||
Debug bool
|
||||
// Set via OLLAMA_LLM_LIBRARY in the environment
|
||||
LLMLibrary string
|
||||
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
|
||||
MaxRunners int
|
||||
// Set via OLLAMA_MAX_QUEUE in the environment
|
||||
MaxQueuedRequests int
|
||||
// Set via OLLAMA_MAX_VRAM in the environment
|
||||
MaxVRAM uint64
|
||||
// Set via OLLAMA_NOPRUNE in the environment
|
||||
NoPrune bool
|
||||
// Set via OLLAMA_NUM_PARALLEL in the environment
|
||||
NumParallel int
|
||||
// Set via OLLAMA_RUNNERS_DIR in the environment
|
||||
RunnersDir string
|
||||
// Set via OLLAMA_TMPDIR in the environment
|
||||
TmpDir string
|
||||
)
|
||||
|
||||
func AsMap() map[string]string {
|
||||
return map[string]string{
|
||||
"OLLAMA_ORIGINS": fmt.Sprintf("%v", AllowOrigins),
|
||||
"OLLAMA_DEBUG": fmt.Sprintf("%v", Debug),
|
||||
"OLLAMA_LLM_LIBRARY": fmt.Sprintf("%v", LLMLibrary),
|
||||
"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
|
||||
"OLLAMA_MAX_QUEUE": fmt.Sprintf("%v", MaxQueuedRequests),
|
||||
"OLLAMA_MAX_VRAM": fmt.Sprintf("%v", MaxVRAM),
|
||||
"OLLAMA_NOPRUNE": fmt.Sprintf("%v", NoPrune),
|
||||
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
|
||||
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
|
||||
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
|
||||
}
|
||||
}
|
||||
|
||||
var defaultAllowOrigins = []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
}
|
||||
|
||||
// Clean quotes and spaces from the value
|
||||
func clean(key string) string {
|
||||
return strings.Trim(os.Getenv(key), "\"' ")
|
||||
}
|
||||
|
||||
func init() {
|
||||
// default values
|
||||
NumParallel = 1
|
||||
MaxRunners = 1
|
||||
MaxQueuedRequests = 512
|
||||
|
||||
LoadConfig()
|
||||
}
|
||||
|
||||
func LoadConfig() {
|
||||
if debug := clean("OLLAMA_DEBUG"); debug != "" {
|
||||
d, err := strconv.ParseBool(debug)
|
||||
if err == nil {
|
||||
Debug = d
|
||||
} else {
|
||||
Debug = true
|
||||
}
|
||||
}
|
||||
|
||||
RunnersDir = clean("OLLAMA_RUNNERS_DIR")
|
||||
if runtime.GOOS == "windows" && RunnersDir == "" {
|
||||
// On Windows we do not carry the payloads inside the main executable
|
||||
appExe, err := os.Executable()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup executable path", "error", err)
|
||||
}
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
slog.Error("failed to lookup working directory", "error", err)
|
||||
}
|
||||
|
||||
var paths []string
|
||||
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
||||
paths = append(paths,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
||||
)
|
||||
}
|
||||
|
||||
// Try a few variations to improve developer experience when building from source in the local tree
|
||||
for _, p := range paths {
|
||||
candidate := filepath.Join(p, "ollama_runners")
|
||||
_, err := os.Stat(candidate)
|
||||
if err == nil {
|
||||
RunnersDir = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
if RunnersDir == "" {
|
||||
slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
|
||||
}
|
||||
}
|
||||
|
||||
TmpDir = clean("OLLAMA_TMPDIR")
|
||||
|
||||
userLimit := clean("OLLAMA_MAX_VRAM")
|
||||
if userLimit != "" {
|
||||
avail, err := strconv.ParseUint(userLimit, 10, 64)
|
||||
if err != nil {
|
||||
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
|
||||
} else {
|
||||
MaxVRAM = avail
|
||||
}
|
||||
}
|
||||
|
||||
LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
|
||||
|
||||
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||
val, err := strconv.Atoi(onp)
|
||||
if err != nil || val <= 0 {
|
||||
slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
|
||||
} else {
|
||||
NumParallel = val
|
||||
}
|
||||
}
|
||||
|
||||
if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
|
||||
NoPrune = true
|
||||
}
|
||||
|
||||
if origins := clean("OLLAMA_ORIGINS"); origins != "" {
|
||||
AllowOrigins = strings.Split(origins, ",")
|
||||
}
|
||||
for _, allowOrigin := range defaultAllowOrigins {
|
||||
AllowOrigins = append(AllowOrigins,
|
||||
fmt.Sprintf("http://%s", allowOrigin),
|
||||
fmt.Sprintf("https://%s", allowOrigin),
|
||||
fmt.Sprintf("http://%s:*", allowOrigin),
|
||||
fmt.Sprintf("https://%s:*", allowOrigin),
|
||||
)
|
||||
}
|
||||
|
||||
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
|
||||
if maxRunners != "" {
|
||||
m, err := strconv.Atoi(maxRunners)
|
||||
if err != nil {
|
||||
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
|
||||
} else {
|
||||
MaxRunners = m
|
||||
}
|
||||
}
|
||||
|
||||
if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
|
||||
p, err := strconv.Atoi(onp)
|
||||
if err != nil || p <= 0 {
|
||||
slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
|
||||
} else {
|
||||
MaxQueuedRequests = p
|
||||
}
|
||||
}
|
||||
}
|
||||
20
server/envconfig/config_test.go
Normal file
20
server/envconfig/config_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
os.Setenv("OLLAMA_DEBUG", "")
|
||||
LoadConfig()
|
||||
require.False(t, Debug)
|
||||
os.Setenv("OLLAMA_DEBUG", "false")
|
||||
LoadConfig()
|
||||
require.False(t, Debug)
|
||||
os.Setenv("OLLAMA_DEBUG", "1")
|
||||
LoadConfig()
|
||||
require.True(t, Debug)
|
||||
}
|
||||
196
server/images.go
196
server/images.go
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -20,15 +21,16 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -61,6 +63,76 @@ func (m *Model) IsEmbedding() bool {
|
||||
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
|
||||
}
|
||||
|
||||
func (m *Model) String() string {
|
||||
var modelfile model.File
|
||||
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "model",
|
||||
Args: m.ModelPath,
|
||||
})
|
||||
|
||||
if m.Template != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "template",
|
||||
Args: m.Template,
|
||||
})
|
||||
}
|
||||
|
||||
if m.System != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "system",
|
||||
Args: m.System,
|
||||
})
|
||||
}
|
||||
|
||||
for _, adapter := range m.AdapterPaths {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "adapter",
|
||||
Args: adapter,
|
||||
})
|
||||
}
|
||||
|
||||
for _, projector := range m.ProjectorPaths {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "projector",
|
||||
Args: projector,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
switch v := v.(type) {
|
||||
case []any:
|
||||
for _, s := range v {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: k,
|
||||
Args: fmt.Sprintf("%v", s),
|
||||
})
|
||||
}
|
||||
default:
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: k,
|
||||
Args: fmt.Sprintf("%v", v),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, license := range m.License {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "license",
|
||||
Args: license,
|
||||
})
|
||||
}
|
||||
|
||||
for _, msg := range m.Messages {
|
||||
modelfile.Commands = append(modelfile.Commands, model.Command{
|
||||
Name: "message",
|
||||
Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
|
||||
})
|
||||
}
|
||||
|
||||
return modelfile.String()
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
@@ -285,7 +357,7 @@ func realpath(mfDir, from string) string {
|
||||
return abspath
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
|
||||
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
|
||||
deleteMap := make(map[string]struct{})
|
||||
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
||||
for _, layer := range append(manifest.Layers, manifest.Config) {
|
||||
@@ -307,7 +379,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
|
||||
params := make(map[string][]string)
|
||||
fromParams := make(map[string]any)
|
||||
|
||||
for _, c := range commands {
|
||||
for _, c := range modelfile.Commands {
|
||||
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||
|
||||
switch c.Name {
|
||||
@@ -624,7 +696,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
|
||||
return err
|
||||
}
|
||||
|
||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
if !envconfig.NoPrune {
|
||||
if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -710,6 +782,10 @@ func CopyModel(src, dst model.Name) error {
|
||||
return model.Unqualified(src)
|
||||
}
|
||||
|
||||
if src.Filepath() == dst.Filepath() {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -894,67 +970,6 @@ func DeleteModel(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ShowModelfile(model *Model) (string, error) {
|
||||
var mt struct {
|
||||
*Model
|
||||
From string
|
||||
Parameters map[string][]any
|
||||
}
|
||||
|
||||
mt.Parameters = make(map[string][]any)
|
||||
for k, v := range model.Options {
|
||||
if s, ok := v.([]any); ok {
|
||||
mt.Parameters[k] = s
|
||||
continue
|
||||
}
|
||||
|
||||
mt.Parameters[k] = []any{v}
|
||||
}
|
||||
|
||||
mt.Model = model
|
||||
mt.From = model.ModelPath
|
||||
|
||||
if model.ParentModel != "" {
|
||||
mt.From = model.ParentModel
|
||||
}
|
||||
|
||||
modelFile := `# Modelfile generated by "ollama show"
|
||||
# To build a new Modelfile based on this one, replace the FROM line with:
|
||||
# FROM {{ .ShortName }}
|
||||
|
||||
FROM {{ .From }}
|
||||
TEMPLATE """{{ .Template }}"""
|
||||
|
||||
{{- if .System }}
|
||||
SYSTEM """{{ .System }}"""
|
||||
{{- end }}
|
||||
|
||||
{{- range $adapter := .AdapterPaths }}
|
||||
ADAPTER {{ $adapter }}
|
||||
{{- end }}
|
||||
|
||||
{{- range $k, $v := .Parameters }}
|
||||
{{- range $parameter := $v }}
|
||||
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
||||
{{- end }}
|
||||
{{- end }}`
|
||||
|
||||
tmpl, err := template.New("").Parse(modelFile)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("error parsing template: %q", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err = tmpl.Execute(&buf, mt); err != nil {
|
||||
slog.Info(fmt.Sprintf("error executing template: %q", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
@@ -976,9 +991,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
for _, layer := range layers {
|
||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||
if errors.Is(err, errUnauthorized) {
|
||||
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1015,7 +1027,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
// build deleteMap to prune unused layers
|
||||
deleteMap := make(map[string]struct{})
|
||||
|
||||
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
if !envconfig.NoPrune {
|
||||
manifest, _, err = GetManifest(mp)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
@@ -1141,9 +1153,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized")
|
||||
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
||||
|
||||
// getTokenSubject returns the subject of a JWT token, it does not validate the token
|
||||
func getTokenSubject(token string) string {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
slog.Error("jwt token does not contain 3 parts")
|
||||
return ""
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
|
||||
return ""
|
||||
}
|
||||
|
||||
var payloadMap map[string]interface{}
|
||||
if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
|
||||
return ""
|
||||
}
|
||||
|
||||
sub, ok := payloadMap["sub"]
|
||||
if !ok {
|
||||
slog.Error("jwt does not contain 'sub' field")
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s", sub)
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
@@ -1162,6 +1205,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
anonymous = getTokenSubject(token) == "anonymous"
|
||||
regOpts.Token = token
|
||||
if body != nil {
|
||||
_, err = body.Seek(0, io.SeekStart)
|
||||
@@ -1182,6 +1226,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
}
|
||||
}
|
||||
|
||||
if anonymous {
|
||||
// no user is associated with the public key, and the request requires non-anonymous access
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||
}
|
||||
// user is associated with the public key, but is not authorized to make the request
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -25,9 +26,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
@@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
re := regexp.MustCompile(pattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if digest != "" && !re.MatchString(digest) {
|
||||
return "", ErrInvalidDigestFormat
|
||||
}
|
||||
|
||||
digest = strings.ReplaceAll(digest, ":", "-")
|
||||
path := filepath.Join(dir, "blobs", digest)
|
||||
dirPath := filepath.Dir(path)
|
||||
|
||||
@@ -1,6 +1,73 @@
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetBlobsPath(t *testing.T) {
|
||||
// GetBlobsPath expects an actual directory to exist
|
||||
dir, err := os.MkdirTemp("", "ollama-test")
|
||||
assert.Nil(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expected string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"empty digest",
|
||||
"",
|
||||
filepath.Join(dir, "blobs"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with colon",
|
||||
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid with dash",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
|
||||
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"digest too short",
|
||||
"sha256-45640291",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest too long",
|
||||
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
{
|
||||
"digest invalid chars",
|
||||
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
|
||||
"",
|
||||
ErrInvalidDigestFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", dir)
|
||||
|
||||
got, err := GetBlobsPath(tc.digest)
|
||||
|
||||
assert.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.Equal(t, tc.expected, got, tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
116
server/routes.go
116
server/routes.go
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -28,7 +29,7 @@ import (
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -146,12 +147,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -394,12 +390,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -522,28 +513,17 @@ func (s *Server) PushModelHandler(c *gin.Context) {
|
||||
|
||||
func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
var req api.CreateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var model string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := ParseModelPath(model).Validate(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -552,19 +532,19 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var modelfile io.Reader = strings.NewReader(req.Modelfile)
|
||||
var r io.Reader = strings.NewReader(req.Modelfile)
|
||||
if req.Path != "" && req.Modelfile == "" {
|
||||
mf, err := os.Open(req.Path)
|
||||
f, err := os.Open(req.Path)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
|
||||
return
|
||||
}
|
||||
defer mf.Close()
|
||||
defer f.Close()
|
||||
|
||||
modelfile = mf
|
||||
r = f
|
||||
}
|
||||
|
||||
commands, err := parser.Parse(modelfile)
|
||||
modelfile, err := model.ParseFile(r)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -580,7 +560,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
|
||||
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -728,12 +708,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := ShowModelfile(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Modelfile = mf
|
||||
var sb strings.Builder
|
||||
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
|
||||
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
|
||||
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
|
||||
fmt.Fprint(&sb, model.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
@@ -810,16 +790,13 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
|
||||
|
||||
src := model.ParseName(r.Source)
|
||||
if !src.IsValid() {
|
||||
_ = c.Error(fmt.Errorf("source %q is invalid", r.Source))
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
|
||||
return
|
||||
}
|
||||
|
||||
dst := model.ParseName(r.Destination)
|
||||
if !dst.IsValid() {
|
||||
_ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination))
|
||||
}
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Source)})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -883,12 +860,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
var defaultAllowOrigins = []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
}
|
||||
|
||||
func isLocalIP(ip netip.Addr) bool {
|
||||
if interfaces, err := net.Interfaces(); err == nil {
|
||||
for _, iface := range interfaces {
|
||||
@@ -972,19 +943,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowWildcard = true
|
||||
config.AllowBrowserExtensions = true
|
||||
|
||||
if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
|
||||
config.AllowOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
|
||||
for _, allowOrigin := range defaultAllowOrigins {
|
||||
config.AllowOrigins = append(config.AllowOrigins,
|
||||
fmt.Sprintf("http://%s", allowOrigin),
|
||||
fmt.Sprintf("https://%s", allowOrigin),
|
||||
fmt.Sprintf("http://%s:*", allowOrigin),
|
||||
fmt.Sprintf("https://%s:*", allowOrigin),
|
||||
)
|
||||
}
|
||||
config.AllowOrigins = envconfig.AllowOrigins
|
||||
|
||||
r := gin.Default()
|
||||
r.Use(
|
||||
@@ -1023,10 +982,11 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
|
||||
func Serve(ln net.Listener) error {
|
||||
level := slog.LevelInfo
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
if envconfig.Debug {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
slog.Info("server config", "env", envconfig.AsMap())
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
@@ -1050,7 +1010,7 @@ func Serve(ln net.Listener) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
|
||||
if !envconfig.NoPrune {
|
||||
// clean up unused layers and manifests
|
||||
if err := PruneLayers(); err != nil {
|
||||
return err
|
||||
@@ -1081,6 +1041,7 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
srvr.Close()
|
||||
done()
|
||||
sched.unloadAllRunners()
|
||||
gpu.Cleanup()
|
||||
@@ -1226,12 +1187,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1352,3 +1308,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func handleErrorResponse(c *gin.Context, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrMaxQueue) {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) {
|
||||
createTestModel := func(t *testing.T, name string) {
|
||||
fname := createTestFile(t, "ollama-model")
|
||||
|
||||
modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
|
||||
commands, err := parser.Parse(modelfile)
|
||||
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
|
||||
modelfile, err := model.ParseFile(r)
|
||||
assert.Nil(t, err)
|
||||
fn := func(resp api.ProgressResponse) {
|
||||
t.Logf("Status: %s", resp.Status)
|
||||
}
|
||||
err = CreateModel(context.TODO(), name, "", "", commands, fn)
|
||||
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -238,6 +238,5 @@ func Test_Routes(t *testing.T) {
|
||||
if tc.Expected != nil {
|
||||
tc.Expected(t, resp)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,6 +15,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
@@ -43,35 +42,14 @@ type Scheduler struct {
|
||||
getGpuFn func() gpu.GpuInfoList
|
||||
}
|
||||
|
||||
// TODO set this to zero after a release or two, to enable multiple models by default
|
||||
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
|
||||
var maxQueuedRequests = 10 // TODO configurable
|
||||
var numParallel = 1
|
||||
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
||||
|
||||
func InitScheduler(ctx context.Context) *Scheduler {
|
||||
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
|
||||
if maxRunners != "" {
|
||||
m, err := strconv.Atoi(maxRunners)
|
||||
if err != nil {
|
||||
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
|
||||
} else {
|
||||
loadedMax = m
|
||||
}
|
||||
}
|
||||
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||
p, err := strconv.Atoi(onp)
|
||||
if err != nil || p <= 0 {
|
||||
slog.Error("invalid parallel setting, must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
|
||||
} else {
|
||||
numParallel = p
|
||||
}
|
||||
}
|
||||
|
||||
sched := &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
|
||||
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
|
||||
expiredCh: make(chan *runnerRef, maxQueuedRequests),
|
||||
unloadedCh: make(chan interface{}, maxQueuedRequests),
|
||||
pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
|
||||
finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
|
||||
expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests),
|
||||
unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: llm.NewLlamaServer,
|
||||
getGpuFn: gpu.GetGPUInfo,
|
||||
@@ -82,6 +60,9 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
|
||||
// allocate a large enough kv cache for all parallel requests
|
||||
opts.NumCtx = opts.NumCtx * envconfig.NumParallel
|
||||
|
||||
req := &LlmRequest{
|
||||
ctx: c,
|
||||
model: model,
|
||||
@@ -90,12 +71,11 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
|
||||
successCh: make(chan *runnerRef),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
// context split across parallel threads
|
||||
opts.NumCtx = opts.NumCtx * numParallel
|
||||
|
||||
select {
|
||||
case s.pendingReqCh <- req:
|
||||
default:
|
||||
req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
||||
req.errCh <- ErrMaxQueue
|
||||
}
|
||||
return req.successCh, req.errCh
|
||||
}
|
||||
@@ -120,6 +100,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
return
|
||||
case pending := <-s.pendingReqCh:
|
||||
// Block other requests until we get this pending request running
|
||||
|
||||
if pending.ctx.Err() != nil {
|
||||
slog.Debug("pending request cancelled or timed out, skipping scheduling")
|
||||
continue
|
||||
}
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
s.loadedMu.Lock()
|
||||
@@ -134,11 +120,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||
break
|
||||
}
|
||||
} else if loadedMax > 0 && loadedCount >= loadedMax {
|
||||
} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
|
||||
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
|
||||
runnerToExpire = s.findRunnerToUnload(pending)
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
} else {
|
||||
// Either no models are loaded or below loadedMax
|
||||
// Either no models are loaded or below envconfig.MaxRunners
|
||||
// Get a refreshed GPU list
|
||||
gpus := s.getGpuFn()
|
||||
|
||||
@@ -149,6 +135,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
// If we're CPU only mode, just limit by envconfig.MaxRunners above
|
||||
// TODO handle system memory exhaustion
|
||||
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
|
||||
slog.Debug("cpu mode with existing models, loading")
|
||||
s.loadFn(pending, ggml, gpus)
|
||||
break
|
||||
}
|
||||
|
||||
// No models loaded. Load the model but prefer the best fit.
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||
@@ -169,7 +163,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
s.loadFn(pending, ggml, gpus)
|
||||
break
|
||||
}
|
||||
runnerToExpire = s.findRunnerToUnload(pending)
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
@@ -242,6 +236,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
defer runner.refMu.Unlock()
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
@@ -268,9 +263,9 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
slog.Debug("got lock to unload", "model", runner.model)
|
||||
runner.unload()
|
||||
s.loadedMu.Lock()
|
||||
delete(s.loaded, runner.model)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Debug("runner released", "model", runner.model)
|
||||
@@ -288,6 +283,10 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
||||
runner.refMu.Lock()
|
||||
defer runner.refMu.Unlock()
|
||||
runner.refCount++
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
runner.sessionDuration = pending.sessionDuration
|
||||
pending.successCh <- runner
|
||||
go func() {
|
||||
@@ -418,6 +417,10 @@ type runnerRef struct {
|
||||
|
||||
// The refMu must already be held when calling unload
|
||||
func (runner *runnerRef) unload() {
|
||||
if runner.expireTimer != nil {
|
||||
runner.expireTimer.Stop()
|
||||
runner.expireTimer = nil
|
||||
}
|
||||
if runner.llama != nil {
|
||||
runner.llama.Close()
|
||||
}
|
||||
@@ -438,6 +441,10 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
||||
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
|
||||
}
|
||||
|
||||
if runner.Options == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Don't reload runner if num_gpu=-1 was provided
|
||||
optsExisting := runner.Options.Runner
|
||||
optsNew := req.opts.Runner
|
||||
@@ -507,7 +514,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
|
||||
}
|
||||
|
||||
// findRunnerToUnload finds a runner to unload to make room for a new model
|
||||
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
|
||||
func (s *Scheduler) findRunnerToUnload() *runnerRef {
|
||||
s.loadedMu.Lock()
|
||||
runnerList := make([]*runnerRef, 0, len(s.loaded))
|
||||
for _, r := range s.loaded {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/server/envconfig"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -27,30 +28,21 @@ func init() {
|
||||
func TestInitScheduler(t *testing.T) {
|
||||
ctx, done := context.WithCancel(context.Background())
|
||||
defer done()
|
||||
initialMax := loadedMax
|
||||
s := InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
require.NotNil(t, s.loaded)
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, initialMax, loadedMax)
|
||||
require.NotNil(t, s.loaded)
|
||||
|
||||
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
|
||||
s = InitScheduler(ctx)
|
||||
require.Equal(t, 0, loadedMax)
|
||||
s.loadedMu.Lock()
|
||||
require.NotNil(t, s.loaded)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer done()
|
||||
s := InitScheduler(ctx)
|
||||
var ggml *llm.GGML // value not used in tests
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
model: &Model{ModelPath: "foo"},
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
sessionDuration: 2,
|
||||
@@ -63,7 +55,9 @@ func TestLoad(t *testing.T) {
|
||||
s.load(req, ggml, gpus)
|
||||
require.Len(t, req.successCh, 0)
|
||||
require.Len(t, req.errCh, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
err := <-req.errCh
|
||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||
|
||||
@@ -78,7 +72,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
require.Equal(t, uint64(10), resp.estimatedVRAM)
|
||||
require.Equal(t, uint(1), resp.refCount)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
req.model.ModelPath = "dummy_model_path"
|
||||
@@ -90,7 +86,9 @@ func TestLoad(t *testing.T) {
|
||||
case resp := <-req.successCh:
|
||||
t.Errorf("unexpected success %v", resp)
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded["dummy_model_path"]
|
||||
s.loadedMu.Unlock()
|
||||
require.NotNil(t, runner)
|
||||
require.Equal(t, uint(0), runner.refCount)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
@@ -143,6 +141,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||
scenario.req = &LlmRequest{
|
||||
ctx: scenario.ctx,
|
||||
model: model,
|
||||
opts: api.DefaultOptions(),
|
||||
sessionDuration: 5 * time.Millisecond,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
@@ -152,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||
}
|
||||
|
||||
func TestRequests(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
// Same model, same request
|
||||
@@ -171,7 +170,9 @@ func TestRequests(t *testing.T) {
|
||||
// Multiple loaded models
|
||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
@@ -225,7 +226,7 @@ func TestRequests(t *testing.T) {
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
|
||||
loadedMax = 1
|
||||
envconfig.MaxRunners = 1
|
||||
s.newServerFn = scenario3a.newServer
|
||||
slog.Info("scenario3a")
|
||||
s.pendingReqCh <- scenario3a.req
|
||||
@@ -240,9 +241,11 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
loadedMax = 0
|
||||
envconfig.MaxRunners = 0
|
||||
s.newServerFn = scenario3b.newServer
|
||||
slog.Info("scenario3b")
|
||||
s.pendingReqCh <- scenario3b.req
|
||||
@@ -254,19 +257,14 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that wont fit
|
||||
// This is a CPU load with NumGPU = 0 so it should load
|
||||
s.newServerFn = scenario3c.newServer
|
||||
slog.Info("scenario3c")
|
||||
require.Len(t, s.loaded, 2)
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3c.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
require.Len(t, s.loaded, 1)
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3c.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3c.srv)
|
||||
@@ -275,11 +273,40 @@ func TestRequests(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that wont fit
|
||||
s.newServerFn = scenario3d.newServer
|
||||
slog.Info("scenario3d")
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3d.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3d.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3d.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, scenario3d.req.errCh, 0)
|
||||
case <-ctx.Done():
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
}
|
||||
|
||||
func TestGetRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
// Same model, same request
|
||||
@@ -289,7 +316,7 @@ func TestGetRunner(t *testing.T) {
|
||||
scenario1b.req.sessionDuration = 0
|
||||
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
||||
scenario1c.req.sessionDuration = 0
|
||||
maxQueuedRequests = 1
|
||||
envconfig.MaxQueuedRequests = 1
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "metal"}
|
||||
@@ -318,17 +345,19 @@ func TestGetRunner(t *testing.T) {
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
scenario1a.ctxDone()
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
scenario1c.req.model.ModelPath = "bad path"
|
||||
slog.Info("scenario1c")
|
||||
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, successCh1c, 0)
|
||||
require.Len(t, errCh1c, 0)
|
||||
|
||||
// Starts in pending channel, then should be quickly processsed to return an error
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
require.Len(t, successCh1c, 0)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 0)
|
||||
s.loadedMu.Unlock()
|
||||
require.Len(t, errCh1c, 1)
|
||||
err = <-errCh1c
|
||||
require.Contains(t, err.Error(), "bad path")
|
||||
@@ -337,7 +366,7 @@ func TestGetRunner(t *testing.T) {
|
||||
|
||||
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||
func TestPrematureExpired(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
// Same model, same request
|
||||
@@ -358,7 +387,9 @@ func TestPrematureExpired(t *testing.T) {
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
require.Len(t, s.pendingReqCh, 0)
|
||||
require.Len(t, errCh1a, 0)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
slog.Info("sending premature expired event now")
|
||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||
case <-ctx.Done():
|
||||
@@ -380,9 +411,10 @@ func TestPrematureExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUseLoadedRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
req := &LlmRequest{
|
||||
ctx: ctx,
|
||||
opts: api.DefaultOptions(),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
sessionDuration: 2,
|
||||
}
|
||||
@@ -404,7 +436,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateFreeSpace(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
gpus := gpu.GpuInfoList{
|
||||
{
|
||||
@@ -426,8 +458,10 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
r2 := &runnerRef{llama: llm2, gpus: gpus}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
s.updateFreeSpace(gpus)
|
||||
require.Equal(t, uint64(850), gpus[0].FreeMemory)
|
||||
@@ -435,33 +469,36 @@ func TestUpdateFreeSpace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFindRunnerToUnload(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
req := &LlmRequest{ctx: ctx}
|
||||
|
||||
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
|
||||
r2 := &runnerRef{sessionDuration: 2}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
resp := s.findRunnerToUnload(req)
|
||||
resp := s.findRunnerToUnload()
|
||||
require.Equal(t, r2, resp)
|
||||
r2.refCount = 1
|
||||
resp = s.findRunnerToUnload(req)
|
||||
resp = s.findRunnerToUnload()
|
||||
require.Equal(t, r1, resp)
|
||||
|
||||
}
|
||||
|
||||
func TestNeedsReload(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
llm := &mockLlm{}
|
||||
do := api.DefaultOptions()
|
||||
runner := &runnerRef{
|
||||
adapters: []string{"adapter1"},
|
||||
projectors: []string{"projector1"},
|
||||
Options: &api.Options{},
|
||||
Options: &do,
|
||||
llama: llm,
|
||||
}
|
||||
req := &LlmRequest{
|
||||
@@ -469,7 +506,7 @@ func TestNeedsReload(t *testing.T) {
|
||||
AdapterPaths: []string{"adapter2"},
|
||||
ProjectorPaths: []string{"projector2"},
|
||||
},
|
||||
opts: api.Options{},
|
||||
opts: api.DefaultOptions(),
|
||||
}
|
||||
resp := runner.needsReload(ctx, req)
|
||||
require.True(t, resp)
|
||||
@@ -497,7 +534,7 @@ func TestNeedsReload(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnloadAllRunners(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
llm1 := &mockLlm{}
|
||||
@@ -508,8 +545,10 @@ func TestUnloadAllRunners(t *testing.T) {
|
||||
r1 := &runnerRef{llama: llm1}
|
||||
r2 := &runnerRef{llama: llm2}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loaded["b"] = r2
|
||||
s.loadedMu.Unlock()
|
||||
s.unloadAllRunners()
|
||||
|
||||
require.True(t, llm1.closeCalled)
|
||||
|
||||
18
types/errtypes/errtypes.go
Normal file
18
types/errtypes/errtypes.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||
|
||||
// TODO: This should have a structured response from the API
|
||||
type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
299
types/model/file.go
Normal file
299
types/model/file.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
Commands []Command
|
||||
}
|
||||
|
||||
func (f File) String() string {
|
||||
var sb strings.Builder
|
||||
for _, cmd := range f.Commands {
|
||||
fmt.Fprintln(&sb, cmd.String())
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
Name string
|
||||
Args string
|
||||
}
|
||||
|
||||
func (c Command) String() string {
|
||||
var sb strings.Builder
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
|
||||
default:
|
||||
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateNil state = iota
|
||||
stateName
|
||||
stateValue
|
||||
stateParameter
|
||||
stateMessage
|
||||
stateComment
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
|
||||
)
|
||||
|
||||
func ParseFile(r io.Reader) (*File, error) {
|
||||
var cmd Command
|
||||
var curr state
|
||||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
var f File
|
||||
|
||||
br := bufio.NewReader(r)
|
||||
for {
|
||||
r, _, err := br.ReadRune()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
next, r, err := parseRuneForState(r, curr)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// process the state transition, some transitions need to be intercepted and redirected
|
||||
if next != curr {
|
||||
switch curr {
|
||||
case stateName:
|
||||
if !isValidCommand(b.String()) {
|
||||
return nil, errInvalidCommand
|
||||
}
|
||||
|
||||
// next state sometimes depends on the current buffer value
|
||||
switch s := strings.ToLower(b.String()); s {
|
||||
case "from":
|
||||
cmd.Name = "model"
|
||||
case "parameter":
|
||||
// transition to stateParameter which sets command name
|
||||
next = stateParameter
|
||||
case "message":
|
||||
// transition to stateMessage which validates the message role
|
||||
next = stateMessage
|
||||
fallthrough
|
||||
default:
|
||||
cmd.Name = s
|
||||
}
|
||||
case stateParameter:
|
||||
cmd.Name = b.String()
|
||||
case stateMessage:
|
||||
if !isValidMessageRole(b.String()) {
|
||||
return nil, errInvalidMessageRole
|
||||
}
|
||||
|
||||
role = b.String()
|
||||
case stateComment, stateNil:
|
||||
// pass
|
||||
case stateValue:
|
||||
s, ok := unquote(b.String())
|
||||
if !ok || isSpace(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if role != "" {
|
||||
s = role + ": " + s
|
||||
role = ""
|
||||
}
|
||||
|
||||
cmd.Args = s
|
||||
f.Commands = append(f.Commands, cmd)
|
||||
}
|
||||
|
||||
b.Reset()
|
||||
curr = next
|
||||
}
|
||||
|
||||
if strconv.IsPrint(r) {
|
||||
if _, err := b.WriteRune(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush the buffer
|
||||
switch curr {
|
||||
case stateComment, stateNil:
|
||||
// pass; nothing to flush
|
||||
case stateValue:
|
||||
s, ok := unquote(b.String())
|
||||
if !ok {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if role != "" {
|
||||
s = role + ": " + s
|
||||
}
|
||||
|
||||
cmd.Args = s
|
||||
f.Commands = append(f.Commands, cmd)
|
||||
default:
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
for _, cmd := range f.Commands {
|
||||
if cmd.Name == "model" {
|
||||
return &f, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errMissingFrom
|
||||
}
|
||||
|
||||
func parseRuneForState(r rune, cs state) (state, rune, error) {
|
||||
switch cs {
|
||||
case stateNil:
|
||||
switch {
|
||||
case r == '#':
|
||||
return stateComment, 0, nil
|
||||
case isSpace(r), isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateName, r, nil
|
||||
}
|
||||
case stateName:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateName, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, errInvalidCommand
|
||||
}
|
||||
case stateValue:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, r, nil
|
||||
case isSpace(r):
|
||||
return stateNil, r, nil
|
||||
default:
|
||||
return stateValue, r, nil
|
||||
}
|
||||
case stateParameter:
|
||||
switch {
|
||||
case isAlpha(r), isNumber(r), r == '_':
|
||||
return stateParameter, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateMessage:
|
||||
switch {
|
||||
case isAlpha(r):
|
||||
return stateMessage, r, nil
|
||||
case isSpace(r):
|
||||
return stateValue, 0, nil
|
||||
default:
|
||||
return stateNil, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
case stateComment:
|
||||
switch {
|
||||
case isNewline(r):
|
||||
return stateNil, 0, nil
|
||||
default:
|
||||
return stateComment, 0, nil
|
||||
}
|
||||
default:
|
||||
return stateNil, 0, errors.New("")
|
||||
}
|
||||
}
|
||||
|
||||
func quote(s string) string {
|
||||
if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") {
|
||||
if strings.Contains(s, "\"") {
|
||||
return `"""` + s + `"""`
|
||||
}
|
||||
|
||||
return `"` + s + `"`
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func unquote(s string) (string, bool) {
|
||||
// TODO: single quotes
|
||||
if len(s) >= 3 && s[:3] == `"""` {
|
||||
if len(s) >= 6 && s[len(s)-3:] == `"""` {
|
||||
return s[3 : len(s)-3], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
if len(s) >= 1 && s[0] == '"' {
|
||||
if len(s) >= 2 && s[len(s)-1] == '"' {
|
||||
return s[1 : len(s)-1], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return s, true
|
||||
}
|
||||
|
||||
func isAlpha(r rune) bool {
|
||||
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
|
||||
}
|
||||
|
||||
func isNumber(r rune) bool {
|
||||
return r >= '0' && r <= '9'
|
||||
}
|
||||
|
||||
func isSpace(r rune) bool {
|
||||
return r == ' ' || r == '\t'
|
||||
}
|
||||
|
||||
func isNewline(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
}
|
||||
|
||||
func isValidMessageRole(role string) bool {
|
||||
return role == "system" || role == "user" || role == "assistant"
|
||||
}
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "parameter", "message":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
511
types/model/file_test.go
Normal file
511
types/model/file_test.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseFileFile(t *testing.T) {
|
||||
input := `
|
||||
FROM model1
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
modelfile, err := ParseFile(reader)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "model1"},
|
||||
{Name: "adapter", Args: "adapter1"},
|
||||
{Name: "license", Args: "MIT"},
|
||||
{Name: "param1", Args: "value1"},
|
||||
{Name: "param2", Args: "value2"},
|
||||
{Name: "template", Args: "template1"},
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedCommands, modelfile.Commands)
|
||||
}
|
||||
|
||||
func TestParseFileFrom(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"FROM foo",
|
||||
[]Command{{Name: "model", Args: "foo"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM /path/to/model",
|
||||
[]Command{{Name: "model", Args: "/path/to/model"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM /path/to/model/fp16.bin",
|
||||
[]Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM llama3:latest",
|
||||
[]Command{{Name: "model", Args: "llama3:latest"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"FROM llama3:7b-instruct-q4_K_M",
|
||||
[]Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"", nil, errMissingFrom,
|
||||
},
|
||||
{
|
||||
"PARAMETER param1 value1",
|
||||
nil,
|
||||
errMissingFrom,
|
||||
},
|
||||
{
|
||||
"PARAMETER param1 value1\nFROM foo",
|
||||
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileParametersMissingValue(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
PARAMETER param1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := ParseFile(reader)
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestParseFileBadCommand(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
BADCOMMAND param1 value1
|
||||
`
|
||||
_, err := ParseFile(strings.NewReader(input))
|
||||
assert.ErrorIs(t, err, errInvalidCommand)
|
||||
|
||||
}
|
||||
|
||||
func TestParseFileMessages(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a file parser. Always parse things.
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a file parser. Always parse things.`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system You are a file parser. Always parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
||||
{Name: "message", Args: "user: Hey there!"},
|
||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system """
|
||||
You are a multiline file parser. Always parse things.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE badguy I'm a bad guy!
|
||||
`,
|
||||
nil,
|
||||
errInvalidMessageRole,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
MESSAGE system`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileQuoted(t *testing.T) {
|
||||
var cases = []struct {
|
||||
multiline string
|
||||
expected []Command
|
||||
err error
|
||||
}{
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a
|
||||
multiline system.
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a\nmultiline system.\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a
|
||||
multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a\nmultiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a
|
||||
multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "This is a\nmultiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a multiline system."""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "This is a multiline system."},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """This is a multiline system.""
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM "
|
||||
`,
|
||||
nil,
|
||||
io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """
|
||||
This is a multiline system with "quotes".
|
||||
"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM ""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: ""},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM "'"
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: "'"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM """''"'""'""'"'''''""'""'"""
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "system", Args: `''"'""'""'"'''''""'""'`},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
`
|
||||
FROM foo
|
||||
TEMPLATE """
|
||||
{{ .Prompt }}
|
||||
"""`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: "template", Args: "\n{{ .Prompt }}\n"},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.multiline))
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileParameters(t *testing.T) {
|
||||
var cases = map[string]struct {
|
||||
name, value string
|
||||
}{
|
||||
"numa true": {"numa", "true"},
|
||||
"num_ctx 1": {"num_ctx", "1"},
|
||||
"num_batch 1": {"num_batch", "1"},
|
||||
"num_gqa 1": {"num_gqa", "1"},
|
||||
"num_gpu 1": {"num_gpu", "1"},
|
||||
"main_gpu 1": {"main_gpu", "1"},
|
||||
"low_vram true": {"low_vram", "true"},
|
||||
"f16_kv true": {"f16_kv", "true"},
|
||||
"logits_all true": {"logits_all", "true"},
|
||||
"vocab_only true": {"vocab_only", "true"},
|
||||
"use_mmap true": {"use_mmap", "true"},
|
||||
"use_mlock true": {"use_mlock", "true"},
|
||||
"num_thread 1": {"num_thread", "1"},
|
||||
"num_keep 1": {"num_keep", "1"},
|
||||
"seed 1": {"seed", "1"},
|
||||
"num_predict 1": {"num_predict", "1"},
|
||||
"top_k 1": {"top_k", "1"},
|
||||
"top_p 1.0": {"top_p", "1.0"},
|
||||
"tfs_z 1.0": {"tfs_z", "1.0"},
|
||||
"typical_p 1.0": {"typical_p", "1.0"},
|
||||
"repeat_last_n 1": {"repeat_last_n", "1"},
|
||||
"temperature 1.0": {"temperature", "1.0"},
|
||||
"repeat_penalty 1.0": {"repeat_penalty", "1.0"},
|
||||
"presence_penalty 1.0": {"presence_penalty", "1.0"},
|
||||
"frequency_penalty 1.0": {"frequency_penalty", "1.0"},
|
||||
"mirostat 1": {"mirostat", "1"},
|
||||
"mirostat_tau 1.0": {"mirostat_tau", "1.0"},
|
||||
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
|
||||
"penalize_newline true": {"penalize_newline", "true"},
|
||||
"stop ### User:": {"stop", "### User:"},
|
||||
"stop ### User: ": {"stop", "### User: "},
|
||||
"stop \"### User:\"": {"stop", "### User:"},
|
||||
"stop \"### User: \"": {"stop", "### User: "},
|
||||
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
|
||||
"stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
|
||||
"stop <|endoftext|>": {"stop", "<|endoftext|>"},
|
||||
"stop <|eot_id|>": {"stop", "<|eot_id|>"},
|
||||
"stop </s>": {"stop", "</s>"},
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintln(&b, "FROM foo")
|
||||
fmt.Fprintln(&b, "PARAMETER", k)
|
||||
modelfile, err := ParseFile(&b)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
{Name: v.name, Args: v.value},
|
||||
}, modelfile.Commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileComments(t *testing.T) {
|
||||
var cases = []struct {
|
||||
input string
|
||||
expected []Command
|
||||
}{
|
||||
{
|
||||
`
|
||||
# comment
|
||||
FROM foo
|
||||
`,
|
||||
[]Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileFormatParseFile(t *testing.T) {
|
||||
var cases = []string{
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system You are a file parser. Always parse things.
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE MIT
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system """
|
||||
You are a store greeter. Always responsed with "Hello!".
|
||||
"""
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
`
|
||||
FROM foo
|
||||
ADAPTER adapter1
|
||||
LICENSE """
|
||||
Very long and boring legal text.
|
||||
Blah blah blah.
|
||||
"Oh look, a quote!"
|
||||
"""
|
||||
|
||||
PARAMETER param1 value1
|
||||
PARAMETER param2 value2
|
||||
TEMPLATE template1
|
||||
MESSAGE system """
|
||||
You are a store greeter. Always responsed with "Hello!".
|
||||
"""
|
||||
MESSAGE user Hey there!
|
||||
MESSAGE assistant Hello, I want to parse all the things!
|
||||
`,
|
||||
`
|
||||
FROM foo
|
||||
SYSTEM ""
|
||||
`,
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c))
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, modelfile, modelfile2)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -80,9 +81,6 @@ func (k partKind) String() string {
|
||||
//
|
||||
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
|
||||
// is valid.
|
||||
//
|
||||
// It is not directly comparable with other Names. Use [Name.Equal] and
|
||||
// [Name.MapHash] for determining equality and using as a map key.
|
||||
type Name struct {
|
||||
Host string
|
||||
Namespace string
|
||||
@@ -109,20 +107,20 @@ type Name struct {
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: alphanum { alphanum | "-" | "_" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// model:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: alphanum { alphanum | "-" | "_" | "." }*
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: alphanum { alphanum | "-" | ":" }*
|
||||
// length: [2, 80]
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// Most users should use [ParseName] instead, unless need to support
|
||||
// different defaults than DefaultName.
|
||||
@@ -145,18 +143,28 @@ func ParseNameBare(s string) Name {
|
||||
n.RawDigest = MissingPart
|
||||
}
|
||||
|
||||
s, n.Tag, _ = cutPromised(s, ":")
|
||||
// "/" is an illegal tag character, so we can use it to split the host
|
||||
if strings.LastIndex(s, ":") > strings.LastIndex(s, "/") {
|
||||
s, n.Tag, _ = cutPromised(s, ":")
|
||||
}
|
||||
|
||||
s, n.Model, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Model = s
|
||||
return n
|
||||
}
|
||||
|
||||
s, n.Namespace, promised = cutPromised(s, "/")
|
||||
if !promised {
|
||||
n.Namespace = s
|
||||
return n
|
||||
}
|
||||
n.Host = s
|
||||
|
||||
scheme, host, ok := strings.Cut(s, "://")
|
||||
if !ok {
|
||||
host = scheme
|
||||
}
|
||||
n.Host = host
|
||||
|
||||
return n
|
||||
}
|
||||
@@ -234,12 +242,12 @@ func (n Name) Filepath() string {
|
||||
if !n.IsFullyQualified() {
|
||||
panic("illegal attempt to get filepath of invalid name")
|
||||
}
|
||||
return filepath.Join(
|
||||
strings.ToLower(n.Host),
|
||||
strings.ToLower(n.Namespace),
|
||||
strings.ToLower(n.Model),
|
||||
strings.ToLower(n.Tag),
|
||||
)
|
||||
return strings.ToLower(filepath.Join(
|
||||
n.Host,
|
||||
n.Namespace,
|
||||
n.Model,
|
||||
n.Tag,
|
||||
))
|
||||
}
|
||||
|
||||
// LogValue returns a slog.Value that represents the name as a string.
|
||||
@@ -254,7 +262,7 @@ func isValidLen(kind partKind, s string) bool {
|
||||
case kindTag:
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
default:
|
||||
return len(s) >= 2 && len(s) <= 80
|
||||
return len(s) >= 1 && len(s) <= 80
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +272,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
}
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
@@ -280,7 +288,7 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumeric(s[i]) {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -288,8 +296,8 @@ func isValidPart(kind partKind, s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func isAlphanumeric(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func cutLast(s, sep string) (before, after string, ok bool) {
|
||||
@@ -311,3 +319,57 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
|
||||
}
|
||||
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
||||
}
|
||||
|
||||
type DigestType byte
|
||||
|
||||
const (
|
||||
DigestTypeInvalid DigestType = iota
|
||||
DigestTypeSHA256
|
||||
)
|
||||
|
||||
func (t DigestType) String() string {
|
||||
switch t {
|
||||
case DigestTypeSHA256:
|
||||
return "sha256"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
type Digest struct {
|
||||
Type DigestType
|
||||
Sum [32]byte
|
||||
}
|
||||
|
||||
func ParseDigest(s string) (Digest, error) {
|
||||
i := strings.IndexAny(s, "-:")
|
||||
if i < 0 {
|
||||
return Digest{}, fmt.Errorf("invalid digest %q", s)
|
||||
}
|
||||
typ, encSum := s[:i], s[i+1:]
|
||||
if typ != "sha256" {
|
||||
return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
|
||||
}
|
||||
d := Digest{
|
||||
Type: DigestTypeSHA256,
|
||||
}
|
||||
n, err := hex.Decode(d.Sum[:], []byte(encSum))
|
||||
if err != nil {
|
||||
return Digest{}, err
|
||||
}
|
||||
if n != 32 {
|
||||
return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d Digest) String() string {
|
||||
if d.Type == DigestTypeInvalid {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sha256-%x", d.Sum)
|
||||
}
|
||||
|
||||
func (d Digest) IsValid() bool {
|
||||
return d.Type != DigestTypeInvalid
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -14,8 +16,19 @@ func TestParseNameParts(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want Name
|
||||
wantFilepath string
|
||||
wantValidDigest bool
|
||||
}{
|
||||
{
|
||||
in: "scheme://host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host/namespace/model:tag",
|
||||
want: Name{
|
||||
@@ -24,6 +37,17 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host:port/namespace/model:tag",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
Tag: "tag",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
|
||||
},
|
||||
{
|
||||
in: "host/namespace/model",
|
||||
@@ -32,6 +56,16 @@ func TestParseNameParts(t *testing.T) {
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("host", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "host:port/namespace/model",
|
||||
want: Name{
|
||||
Host: "host:port",
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "namespace/model",
|
||||
@@ -39,12 +73,14 @@ func TestParseNameParts(t *testing.T) {
|
||||
Namespace: "namespace",
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("registry.ollama.ai", "namespace", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "model",
|
||||
want: Name{
|
||||
Model: "model",
|
||||
},
|
||||
wantFilepath: filepath.Join("registry.ollama.ai", "library", "model", "latest"),
|
||||
},
|
||||
{
|
||||
in: "h/nn/mm:t",
|
||||
@@ -54,6 +90,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: "mm",
|
||||
Tag: "t",
|
||||
},
|
||||
wantFilepath: filepath.Join("h", "nn", "mm", "t"),
|
||||
},
|
||||
{
|
||||
in: part80 + "/" + part80 + "/" + part80 + ":" + part80,
|
||||
@@ -63,6 +100,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: part80,
|
||||
Tag: part80,
|
||||
},
|
||||
wantFilepath: filepath.Join(part80, part80, part80, part80),
|
||||
},
|
||||
{
|
||||
in: part350 + "/" + part80 + "/" + part80 + ":" + part80,
|
||||
@@ -72,6 +110,7 @@ func TestParseNameParts(t *testing.T) {
|
||||
Model: part80,
|
||||
Tag: part80,
|
||||
},
|
||||
wantFilepath: filepath.Join(part350, part80, part80, part80),
|
||||
},
|
||||
{
|
||||
in: "@digest",
|
||||
@@ -96,11 +135,23 @@ func TestParseNameParts(t *testing.T) {
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
|
||||
}
|
||||
|
||||
got = ParseName(tt.in)
|
||||
if tt.wantFilepath != "" && got.Filepath() != tt.wantFilepath {
|
||||
t.Errorf("parseName(%q).Filepath() = %q; want %q", tt.in, got.Filepath(), tt.wantFilepath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t@d": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": false,
|
||||
"namespace/model": false,
|
||||
@@ -116,11 +167,12 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
|
||||
|
||||
"m": false, // model too short
|
||||
"n/mm:": false, // namespace too short
|
||||
"h/n/mm:t": false, // namespace too short
|
||||
"@t": false, // digest too short
|
||||
"mm@d": false, // digest too short
|
||||
// unqualified
|
||||
"m": false,
|
||||
"n/m:": false,
|
||||
"h/n/m": false,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
@@ -140,8 +192,6 @@ var testCases = map[string]bool{ // name -> valid
|
||||
"hh/nn/mm:-tt@dd": false,
|
||||
"hh/nn/mm:tt@-dd": false,
|
||||
|
||||
"": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
@@ -163,7 +213,6 @@ func TestNameIsValid(t *testing.T) {
|
||||
var numStringTests int
|
||||
for s, want := range testCases {
|
||||
n := ParseNameBare(s)
|
||||
t.Logf("n: %#v", n)
|
||||
got := n.IsValid()
|
||||
if got != want {
|
||||
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
|
||||
@@ -212,6 +261,54 @@ func TestNameIsValidPart(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestFilepathAllocs(t *testing.T) {
|
||||
n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
n.Filepath()
|
||||
})
|
||||
allowedAllocs := 2.0
|
||||
if runtime.GOOS == "windows" {
|
||||
allowedAllocs = 4
|
||||
}
|
||||
if allocs > allowedAllocs {
|
||||
t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
|
||||
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
|
||||
)
|
||||
|
||||
func TestParseDigest(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", ""}, // empty
|
||||
{"sha123-12", ""}, // invalid type
|
||||
{"sha256-", ""}, // invalid sum
|
||||
{"sha256-123", ""}, // invalid odd length sum
|
||||
|
||||
{validSha256, validSha256},
|
||||
{validSha256Old, validSha256},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got, err := ParseDigest(tt.in)
|
||||
if err != nil {
|
||||
if tt.want != "" {
|
||||
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got.String() != tt.want {
|
||||
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzName(f *testing.F) {
|
||||
for s := range testCases {
|
||||
f.Add(s)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package structs contains the Incomparable type.
|
||||
package structs
|
||||
|
||||
// Incomparable is a zero-width incomparable type. If added as the
|
||||
// first field in a struct, it marks that struct as not comparable
|
||||
// (can't do == or be a map key) and usually doesn't add any width to
|
||||
// the struct (unless the struct has only small fields).
|
||||
//
|
||||
// By making a struct incomparable, you can prevent misuse (prevent
|
||||
// people from using ==), but also you can shrink generated binaries,
|
||||
// as the compiler can omit equality funcs from the binary.
|
||||
type Incomparable [0]func()
|
||||
Reference in New Issue
Block a user